1
0

Merge commit '22db45bd4' into anoa/dinsic_release_1_31_0

This commit is contained in:
Andrew Morgan
2021-04-23 17:26:31 +01:00
53 changed files with 668 additions and 255 deletions
+9
View File
@@ -1,9 +1,18 @@
Removal warning
---------------
Note that this release deprecates the ability for appservices to call `POST /_matrix/client/r0/register` without the body parameter `type`. Appservice developers should use a `type` value of `m.login.application_service` as per the spec. In future releases, calling this endpoint with an access token but
without a valid type will fail.
Synapse 1.29.0 (2021-03-08)
===========================
Note that synapse now expects an `X-Forwarded-Proto` header when used with a reverse proxy. Please see [UPGRADE.rst](UPGRADE.rst#upgrading-to-v1290) for more details on this change.
No significant changes.
Synapse 1.29.0rc1 (2021-03-04)
==============================
+3 -2
View File
@@ -183,8 +183,9 @@ Using a reverse proxy with Synapse
It is recommended to put a reverse proxy such as
`nginx <https://nginx.org/en/docs/http/ngx_http_proxy_module.html>`_,
`Apache <https://httpd.apache.org/docs/current/mod/mod_proxy_http.html>`_,
`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_ or
`HAProxy <https://www.haproxy.org/>`_ in front of Synapse. One advantage of
`Caddy <https://caddyserver.com/docs/quick-starts/reverse-proxy>`_,
`HAProxy <https://www.haproxy.org/>`_ or
`relayd <https://man.openbsd.org/relayd.8>`_ in front of Synapse. One advantage of
doing so is that it means that you can expose the default https port (443) to
Matrix clients without needing to run Synapse with root privileges.
+7
View File
@@ -124,6 +124,13 @@ This version changes the URI used for callbacks from OAuth2 and SAML2 identity p
need to add ``[synapse public baseurl]/_synapse/client/saml2/authn_response`` as a permitted
"ACS location" (also known as "allowed callback URLs") at the identity provider.
The "Issuer" in the "AuthnRequest" to the SAML2 identity provider is also updated to
``[synapse public baseurl]/_synapse/client/saml2/metadata.xml``. If your SAML2 identity
provider uses this property to validate or otherwise identify Synapse, its configuration
will need to be updated to use the new URL. Alternatively you could create a new, separate
"EntityDescriptor" in your SAML2 identity provider with the new URLs and leave the URLs in
the existing "EntityDescriptor" as they were.
Changes to HTML templates
-------------------------
+1
View File
@@ -0,0 +1 @@
Add tests to ResponseCache.
+1
View File
@@ -0,0 +1 @@
Add relayd entry to reverse proxy example configurations.
+1
View File
@@ -0,0 +1 @@
Add prometheus metrics for number of users successfully registering and logging in.
+1
View File
@@ -0,0 +1 @@
Add prometheus metrics for number of users successfully registering and logging in.
+1
View File
@@ -0,0 +1 @@
Add type hints to purge room and server notice admin API.
+1
View File
@@ -0,0 +1 @@
Add extra logging to ObservableDeferred when callbacks throw exceptions.
+1
View File
@@ -0,0 +1 @@
Fix incorrect type hints.
+1
View File
@@ -0,0 +1 @@
Add `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time` prometheus metrics, which monitor federation delays by reporting the timestamps of messages sent and received to a set of remote servers.
+1
View File
@@ -0,0 +1 @@
The `synapse_federation_last_sent_pdu_age` and `synapse_federation_last_received_pdu_age` prometheus metrics have been removed. They are replaced by `synapse_federation_last_sent_pdu_time` and `synapse_federation_last_received_pdu_time`.
+1
View File
@@ -0,0 +1 @@
Add an additional test for purging a room.
+1
View File
@@ -0,0 +1 @@
Improve the SAML2 upgrade notes for 1.27.0.
+1
View File
@@ -0,0 +1 @@
Registering an Application Service user without using the `m.login.application_service` login type will be unsupported in an upcoming Synapse release.
+1
View File
@@ -0,0 +1 @@
Fix spurious errors reported by the `config-lint.sh` script.
+49 -2
View File
@@ -3,8 +3,9 @@
It is recommended to put a reverse proxy such as
[nginx](https://nginx.org/en/docs/http/ngx_http_proxy_module.html),
[Apache](https://httpd.apache.org/docs/current/mod/mod_proxy_http.html),
[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy) or
[HAProxy](https://www.haproxy.org/) in front of Synapse. One advantage
[Caddy](https://caddyserver.com/docs/quick-starts/reverse-proxy),
[HAProxy](https://www.haproxy.org/) or
[relayd](https://man.openbsd.org/relayd.8) in front of Synapse. One advantage
of doing so is that it means that you can expose the default https port
(443) to Matrix clients without needing to run Synapse with root
privileges.
@@ -162,6 +163,52 @@ backend matrix
server matrix 127.0.0.1:8008
```
### Relayd
```
table <webserver> { 127.0.0.1 }
table <matrixserver> { 127.0.0.1 }
http protocol "https" {
tls { no tlsv1.0, ciphers "HIGH" }
tls keypair "example.com"
match header set "X-Forwarded-For" value "$REMOTE_ADDR"
match header set "X-Forwarded-Proto" value "https"
# set CORS header for .well-known/matrix/server, .well-known/matrix/client
# httpd does not support setting headers, so do it here
match request path "/.well-known/matrix/*" tag "matrix-cors"
match response tagged "matrix-cors" header set "Access-Control-Allow-Origin" value "*"
pass quick path "/_matrix/*" forward to <matrixserver>
pass quick path "/_synapse/client/*" forward to <matrixserver>
# pass on non-matrix traffic to webserver
pass forward to <webserver>
}
relay "https_traffic" {
listen on egress port 443 tls
protocol "https"
forward to <matrixserver> port 8008 check tcp
forward to <webserver> port 8080 check tcp
}
http protocol "matrix" {
tls { no tlsv1.0, ciphers "HIGH" }
tls keypair "example.com"
block
pass quick path "/_matrix/*" forward to <matrixserver>
pass quick path "/_synapse/client/*" forward to <matrixserver>
}
relay "matrix_federation" {
listen on egress port 8448 tls
protocol "matrix"
forward to <matrixserver> port 8008 check tcp
}
```
## Homeserver Configuration
You will also want to set `bind_addresses: ['127.0.0.1']` and
+1
View File
@@ -69,6 +69,7 @@ files =
synapse/util/async_helpers.py,
synapse/util/caches,
synapse/util/metrics.py,
synapse/util/macaroons.py,
synapse/util/stringutils.py,
tests/replication,
tests/test_utils,
+7 -2
View File
@@ -2,9 +2,14 @@
# Find linting errors in Synapse's default config file.
# Exits with 0 if there are no problems, or another code otherwise.
# cd to the root of the repository
cd `dirname $0`/..
# Restore backup of sample config upon script exit
trap "mv docs/sample_config.yaml.bak docs/sample_config.yaml" EXIT
# Fix non-lowercase true/false values
sed -i.bak -E "s/: +True/: true/g; s/: +False/: false/g;" docs/sample_config.yaml
rm docs/sample_config.yaml.bak
# Check if anything changed
git diff --exit-code docs/sample_config.yaml
diff docs/sample_config.yaml docs/sample_config.yaml.bak
+9 -32
View File
@@ -39,6 +39,7 @@ from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -413,7 +414,7 @@ class Auth:
raise _InvalidMacaroonException()
try:
user_id = self.get_user_id_from_macaroon(macaroon)
user_id = get_value_from_macaroon(macaroon, "user_id")
guest = False
for caveat in macaroon.caveats:
@@ -421,7 +422,12 @@ class Auth:
guest = True
self.validate_macaroon(macaroon, rights, user_id=user_id)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
except (
pymacaroons.exceptions.MacaroonException,
KeyError,
TypeError,
ValueError,
):
raise InvalidClientTokenError("Invalid macaroon passed.")
if rights == "access":
@@ -429,27 +435,6 @@ class Auth:
return user_id, guest
def get_user_id_from_macaroon(self, macaroon):
"""Retrieve the user_id given by the caveats on the macaroon.
Does *not* validate the macaroon.
Args:
macaroon (pymacaroons.Macaroon): The macaroon to validate
Returns:
(str) user id
Raises:
InvalidClientCredentialsError if there is no user_id caveat in the
macaroon
"""
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
return caveat.caveat_id[len(user_prefix) :]
raise InvalidClientTokenError("No user caveat in macaroon")
def validate_macaroon(self, macaroon, type_string, user_id):
"""
validate that a Macaroon is understood by and was signed by this server.
@@ -470,21 +455,13 @@ class Auth:
v.satisfy_exact("type = " + type_string)
v.satisfy_exact("user_id = %s" % user_id)
v.satisfy_exact("guest = true")
v.satisfy_general(self._verify_expiry)
satisfy_expiry(v, self.clock.time_msec)
# access_tokens include a nonce for uniqueness: any value is acceptable
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.verify(macaroon, self._macaroon_secret_key)
def _verify_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self.hs.get_clock().time_msec()
return now < expiry
def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
+1 -1
View File
@@ -90,7 +90,7 @@ class ApplicationServiceApi(SimpleHttpClient):
self.clock = hs.get_clock()
self.protocol_meta_cache = ResponseCache(
hs, "as_protocol_meta", timeout_ms=HOUR_IN_MS
hs.get_clock(), "as_protocol_meta", timeout_ms=HOUR_IN_MS
) # type: ResponseCache[Tuple[str, str]]
async def query_user(self, service, user_id):
+1 -2
View File
@@ -847,8 +847,7 @@ class ServerConfig(Config):
# Whether to require authentication to retrieve profile data (avatars,
# display names) of other users through the client API. Defaults to
# 'false'. Note that profile data is also available via the federation
# API, so this setting is of limited value if federation is enabled on
# the server.
# API, unless allow_profile_lookup_over_federation is set to false.
#
#require_auth_for_profile_requests: true
+12 -11
View File
@@ -22,6 +22,7 @@ from typing import (
Awaitable,
Callable,
Dict,
Iterable,
List,
Optional,
Tuple,
@@ -91,16 +92,15 @@ pdu_process_time = Histogram(
"Time taken to process an event",
)
last_pdu_age_metric = Gauge(
"synapse_federation_last_received_pdu_age",
"The age (in seconds) of the last PDU successfully received from the given domain",
last_pdu_ts_metric = Gauge(
"synapse_federation_last_received_pdu_time",
"The timestamp of the last PDU which was successfully received from the given domain",
labelnames=("server_name",),
)
class FederationServer(FederationBase):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.auth = hs.get_auth()
@@ -120,7 +120,7 @@ class FederationServer(FederationBase):
# We cache results for transaction with the same ID
self._transaction_resp_cache = ResponseCache(
hs, "fed_txn_handler", timeout_ms=30000
hs.get_clock(), "fed_txn_handler", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self.transaction_actions = TransactionActions(self.store)
@@ -130,10 +130,10 @@ class FederationServer(FederationBase):
# We cache responses to state queries, as they take a while and often
# come in waves.
self._state_resp_cache = ResponseCache(
hs, "state_resp", timeout_ms=30000
hs.get_clock(), "state_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._state_ids_resp_cache = ResponseCache(
hs, "state_ids_resp", timeout_ms=30000
hs.get_clock(), "state_ids_resp", timeout_ms=30000
) # type: ResponseCache[Tuple[str, str]]
self._federation_metrics_domains = (
@@ -370,8 +370,7 @@ class FederationServer(FederationBase):
)
if newest_pdu_ts and origin in self._federation_metrics_domains:
newest_pdu_age = self._clock.time_msec() - newest_pdu_ts
last_pdu_age_metric.labels(server_name=origin).set(newest_pdu_age / 1000)
last_pdu_ts_metric.labels(server_name=origin).set(newest_pdu_ts / 1000)
return pdu_results
@@ -456,7 +455,9 @@ class FederationServer(FederationBase):
self, room_id: str, event_id: str
) -> Dict[str, list]:
if event_id:
pdus = await self.handler.get_state_for_pdu(room_id, event_id)
pdus = await self.handler.get_state_for_pdu(
room_id, event_id
) # type: Iterable[EventBase]
else:
pdus = (await self.state.get_current_state(room_id)).values()
@@ -36,9 +36,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
last_pdu_age_metric = Gauge(
"synapse_federation_last_sent_pdu_age",
"The age (in seconds) of the last PDU successfully sent to the given domain",
last_pdu_ts_metric = Gauge(
"synapse_federation_last_sent_pdu_time",
"The timestamp of the last PDU which was successfully sent to the given domain",
labelnames=("server_name",),
)
@@ -187,9 +187,8 @@ class TransactionManager:
if success and pdus and destination in self._federation_metrics_domains:
last_pdu = pdus[-1]
last_pdu_age = self.clock.time_msec() - last_pdu.origin_server_ts
last_pdu_age_metric.labels(server_name=destination).set(
last_pdu_age / 1000
last_pdu_ts_metric.labels(server_name=destination).set(
last_pdu.origin_server_ts / 1000
)
set_tag(tags.ERROR, not success)
+3 -1
View File
@@ -73,7 +73,9 @@ class AcmeHandler:
"Listening for ACME requests on %s:%i", host, self.hs.config.acme_port
)
try:
self.reactor.listenTCP(self.hs.config.acme_port, srv, interface=host)
self.reactor.listenTCP(
self.hs.config.acme_port, srv, backlog=50, interface=host
)
except twisted.internet.error.CannotListenError as e:
check_bind_error(e, host, bind_addresses)
+58 -10
View File
@@ -65,6 +65,7 @@ from synapse.storage.roommember import ProfileInfo
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.async_helpers import maybe_awaitable
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -170,6 +171,16 @@ class SsoLoginExtraAttributes:
extra_attributes = attr.ib(type=JsonDict)
@attr.s(slots=True, frozen=True)
class LoginTokenAttributes:
"""Data we store in a short-term login token"""
user_id = attr.ib(type=str)
# the SSO Identity Provider that the user authenticated with, to get this token
auth_provider_id = attr.ib(type=str)
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
@@ -1164,18 +1175,16 @@ class AuthHandler(BaseHandler):
return None
return user_id
async def validate_short_term_login_token_and_get_user_id(self, login_token: str):
auth_api = self.hs.get_auth()
user_id = None
async def validate_short_term_login_token(
self, login_token: str
) -> LoginTokenAttributes:
try:
macaroon = pymacaroons.Macaroon.deserialize(login_token)
user_id = auth_api.get_user_id_from_macaroon(macaroon)
auth_api.validate_macaroon(macaroon, "login", user_id)
res = self.macaroon_gen.verify_short_term_login_token(login_token)
except Exception:
raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN)
await self.auth.check_auth_blocking(user_id)
return user_id
await self.auth.check_auth_blocking(res.user_id)
return res
async def delete_access_token(self, access_token: str):
"""Invalidate a single access token
@@ -1397,6 +1406,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1406,6 +1416,9 @@ class AuthHandler(BaseHandler):
Args:
registered_user_id: The registered user ID to complete SSO login for.
auth_provider_id: The id of the SSO Identity provider that was used for
login. This will be stored in the login token for future tracking in
prometheus metrics.
request: The request to complete.
client_redirect_url: The URL to which to redirect the user at the end of the
process.
@@ -1427,6 +1440,7 @@ class AuthHandler(BaseHandler):
self._complete_sso_login(
registered_user_id,
auth_provider_id,
request,
client_redirect_url,
extra_attributes,
@@ -1437,6 +1451,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
auth_provider_id: str,
request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
@@ -1463,7 +1478,7 @@ class AuthHandler(BaseHandler):
# Create a login token
login_token = self.macaroon_gen.generate_short_term_login_token(
registered_user_id
registered_user_id, auth_provider_id=auth_provider_id
)
# Append the login token to the original redirect URL (i.e. with its query
@@ -1569,15 +1584,48 @@ class MacaroonGenerator:
return macaroon.serialize()
def generate_short_term_login_token(
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
self,
user_id: str,
auth_provider_id: str,
duration_in_ms: int = (2 * 60 * 1000),
) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = login")
now = self.hs.get_clock().time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
macaroon.add_first_party_caveat("auth_provider_id = %s" % (auth_provider_id,))
return macaroon.serialize()
def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes:
"""Verify a short-term-login macaroon
Checks that the given token is a valid, unexpired short-term-login token
minted by this server.
Args:
token: the login token to verify
Returns:
the user_id that this token is valid for
Raises:
MacaroonVerificationFailedException if the verification failed
"""
macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = get_value_from_macaroon(macaroon, "user_id")
auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id")
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = login")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(lambda c: c.startswith("auth_provider_id = "))
satisfy_expiry(v, self.hs.get_clock().time_msec)
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
return LoginTokenAttributes(user_id=user_id, auth_provider_id=auth_provider_id)
def generate_delete_pusher_token(self, user_id: str) -> str:
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = delete_pusher")
+1 -1
View File
@@ -48,7 +48,7 @@ class InitialSyncHandler(BaseHandler):
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = ResponseCache(
hs, "initial_sync_cache"
hs.get_clock(), "initial_sync_cache"
) # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
+14 -51
View File
@@ -42,6 +42,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -211,7 +212,7 @@ class OidcHandler:
session_data = self._token_generator.verify_oidc_session_token(
session, state
)
except (MacaroonDeserializationException, ValueError) as e:
except (MacaroonDeserializationException, KeyError) as e:
logger.exception("Invalid session for OIDC callback")
self._sso_handler.render_error(request, "invalid_session", str(e))
return
@@ -745,7 +746,7 @@ class OidcProvider:
idp_id=self.idp_id,
nonce=nonce,
client_redirect_url=client_redirect_url.decode(),
ui_auth_session_id=ui_auth_session_id,
ui_auth_session_id=ui_auth_session_id or "",
),
)
@@ -1020,10 +1021,9 @@ class OidcSessionTokenGenerator:
macaroon.add_first_party_caveat(
"client_redirect_url = %s" % (session_data.client_redirect_url,)
)
if session_data.ui_auth_session_id:
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
macaroon.add_first_party_caveat(
"ui_auth_session_id = %s" % (session_data.ui_auth_session_id,)
)
now = self._clock.time_msec()
expiry = now + duration_in_ms
macaroon.add_first_party_caveat("time < %d" % (expiry,))
@@ -1046,7 +1046,7 @@ class OidcSessionTokenGenerator:
The data extracted from the session cookie
Raises:
ValueError if an expected caveat is missing from the macaroon.
KeyError if an expected caveat is missing from the macaroon.
"""
macaroon = pymacaroons.Macaroon.deserialize(session)
@@ -1057,26 +1057,16 @@ class OidcSessionTokenGenerator:
v.satisfy_general(lambda c: c.startswith("nonce = "))
v.satisfy_general(lambda c: c.startswith("idp_id = "))
v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
# Sometimes there's a UI auth session ID, it seems to be OK to attempt
# to always satisfy this.
v.satisfy_general(lambda c: c.startswith("ui_auth_session_id = "))
v.satisfy_general(self._verify_expiry)
satisfy_expiry(v, self._clock.time_msec)
v.verify(macaroon, self._macaroon_secret_key)
# Extract the session data from the token.
nonce = self._get_value_from_macaroon(macaroon, "nonce")
idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
client_redirect_url = self._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
try:
ui_auth_session_id = self._get_value_from_macaroon(
macaroon, "ui_auth_session_id"
) # type: Optional[str]
except ValueError:
ui_auth_session_id = None
nonce = get_value_from_macaroon(macaroon, "nonce")
idp_id = get_value_from_macaroon(macaroon, "idp_id")
client_redirect_url = get_value_from_macaroon(macaroon, "client_redirect_url")
ui_auth_session_id = get_value_from_macaroon(macaroon, "ui_auth_session_id")
return OidcSessionData(
nonce=nonce,
idp_id=idp_id,
@@ -1084,33 +1074,6 @@ class OidcSessionTokenGenerator:
ui_auth_session_id=ui_auth_session_id,
)
def _get_value_from_macaroon(self, macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
ValueError: if the caveat was not in the macaroon
"""
prefix = key + " = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(prefix):
return caveat.caveat_id[len(prefix) :]
raise ValueError("No %s caveat in macaroon" % (key,))
def _verify_expiry(self, caveat: str) -> bool:
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
now = self._clock.time_msec()
return now < expiry
@attr.s(frozen=True, slots=True)
class OidcSessionData:
@@ -1125,8 +1088,8 @@ class OidcSessionData:
# The URL the client gave when it initiated the flow. ("" if this is a UI Auth)
client_redirect_url = attr.ib(type=str)
# The session ID of the ongoing UI Auth (None if this is a login)
ui_auth_session_id = attr.ib(type=Optional[str], default=None)
# The session ID of the ongoing UI Auth ("" if this is a login)
ui_auth_session_id = attr.ib(type=str)
UserAttributeDict = TypedDict(
+33 -2
View File
@@ -18,6 +18,8 @@
import logging
from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
from prometheus_client import Counter
from synapse import types
from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
@@ -41,6 +43,19 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
registration_counter = Counter(
"synapse_user_registrations_total",
"Number of new users registered (since restart)",
["guest", "shadow_banned", "auth_provider"],
)
login_counter = Counter(
"synapse_user_logins_total",
"Number of user logins (since restart)",
["guest", "auth_provider"],
)
class RegistrationHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
@@ -171,6 +186,7 @@ class RegistrationHandler(BaseHandler):
bind_emails: Iterable[str] = [],
by_admin: bool = False,
user_agent_ips: Optional[List[Tuple[str, str]]] = None,
auth_provider_id: Optional[str] = None,
) -> str:
"""Registers a new client on the server.
@@ -196,8 +212,10 @@ class RegistrationHandler(BaseHandler):
admin api, otherwise False.
user_agent_ips: Tuples of IP addresses and user-agents used
during the registration process.
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
Returns:
The registere user_id.
The registered user_id.
Raises:
SynapseError if there was a problem registering.
"""
@@ -304,6 +322,12 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
fail_count += 1
registration_counter.labels(
guest=make_guest,
shadow_banned=shadow_banned,
auth_provider=(auth_provider_id or ""),
).inc()
if not self.hs.config.user_consent_at_registration:
if not self.hs.config.auto_join_rooms_for_guests and make_guest:
logger.info(
@@ -718,6 +742,7 @@ class RegistrationHandler(BaseHandler):
initial_display_name: Optional[str],
is_guest: bool = False,
is_appservice_ghost: bool = False,
auth_provider_id: Optional[str] = None,
) -> Tuple[str, str]:
"""Register a device for a user and generate an access token.
@@ -728,7 +753,8 @@ class RegistrationHandler(BaseHandler):
device_id: The device ID to check, or None to generate a new one.
initial_display_name: An optional display name for the device.
is_guest: Whether this is a guest account
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
Returns:
Tuple of device ID and access token
"""
@@ -767,6 +793,11 @@ class RegistrationHandler(BaseHandler):
is_appservice_ghost=is_appservice_ghost,
)
login_counter.labels(
guest=is_guest,
auth_provider=(auth_provider_id or ""),
).inc()
return (registered_device_id, access_token)
async def post_registration_actions(
+1 -1
View File
@@ -121,7 +121,7 @@ class RoomCreationHandler(BaseHandler):
# succession, only process the first attempt and return its result to
# subsequent requests
self._upgrade_response_cache = ResponseCache(
hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
) # type: ResponseCache[Tuple[str, str]]
self._server_notices_mxid = hs.config.server_notices_mxid
+2 -2
View File
@@ -44,10 +44,10 @@ class RoomListHandler(BaseHandler):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(
hs, "room_list"
hs.get_clock(), "room_list"
) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
self.remote_response_cache = ResponseCache(
hs, "remote_room_list", timeout_ms=30 * 1000
hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000
) # type: ResponseCache[Tuple[str, Optional[int], Optional[str], bool, Optional[str]]]
async def get_local_public_room_list(
+3
View File
@@ -456,6 +456,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
auth_provider_id,
request,
client_redirect_url,
extra_login_attributes,
@@ -605,6 +606,7 @@ class SsoHandler:
default_display_name=attributes.display_name,
bind_emails=attributes.emails,
user_agent_ips=[(user_agent, ip_address)],
auth_provider_id=auth_provider_id,
)
await self._store.record_user_external_id(
@@ -886,6 +888,7 @@ class SsoHandler:
await self._auth_handler.complete_sso_login(
user_id,
session.auth_provider_id,
request,
session.client_redirect_url,
session.extra_login_attributes,
+1 -1
View File
@@ -258,7 +258,7 @@ class SyncHandler:
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache(
hs, "sync"
hs.get_clock(), "sync"
) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
+3 -2
View File
@@ -63,6 +63,7 @@ from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_u
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import ISynapseReactor
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
@@ -199,7 +200,7 @@ class _IPBlacklistingResolver:
return r
@implementer(IReactorPluggableNameResolver)
@implementer(ISynapseReactor)
class BlacklistingReactorWrapper:
"""
A Reactor wrapper which will prevent DNS resolution to blacklisted IP
@@ -324,7 +325,7 @@ class SimpleHttpClient:
# filters out blacklisted IP addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
)
) # type: ISynapseReactor
else:
self.reactor = hs.get_reactor()
@@ -35,6 +35,7 @@ from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import ISynapseReactor
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -68,7 +69,7 @@ class MatrixFederationAgent:
def __init__(
self,
reactor: IReactorCore,
reactor: ISynapseReactor,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
ip_blacklist: IPSet,
+4 -4
View File
@@ -59,7 +59,7 @@ from synapse.logging.opentracing import (
start_active_span,
tags,
)
from synapse.types import JsonDict
from synapse.types import ISynapseReactor, JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -237,14 +237,14 @@ class MatrixFederationHttpClient:
# addresses, to prevent DNS rebinding.
self.reactor = BlacklistingReactorWrapper(
hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
)
) # type: ISynapseReactor
user_agent = hs.version_string
if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii")
self.agent = MatrixFederationAgent(
federation_agent = MatrixFederationAgent(
self.reactor,
tls_client_options_factory,
user_agent,
@@ -254,7 +254,7 @@ class MatrixFederationHttpClient:
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent,
federation_agent,
ip_blacklist=hs.config.federation_ip_range_blacklist,
)
+27 -4
View File
@@ -203,11 +203,26 @@ class ModuleApi:
)
def generate_short_term_login_token(
self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
self,
user_id: str,
duration_in_ms: int = (2 * 60 * 1000),
auth_provider_id: str = "",
) -> str:
"""Generate a login token suitable for m.login.token authentication"""
"""Generate a login token suitable for m.login.token authentication
Args:
user_id: gives the ID of the user that the token is for
duration_in_ms: the time that the token will be valid for
auth_provider_id: the ID of the SSO IdP that the user used to authenticate
to get this token, if any. This is encoded in the token so that
/login can report stats on number of successful logins by IdP.
"""
return self._hs.get_macaroon_generator().generate_short_term_login_token(
user_id, duration_in_ms
user_id,
auth_provider_id,
duration_in_ms,
)
@defer.inlineCallbacks
@@ -276,6 +291,7 @@ class ModuleApi:
"""
self._auth_handler._complete_sso_login(
registered_user_id,
"<unknown>",
request,
client_redirect_url,
)
@@ -286,6 +302,7 @@ class ModuleApi:
request: SynapseRequest,
client_redirect_url: str,
new_user: bool = False,
auth_provider_id: str = "<unknown>",
):
"""Complete a SSO login by redirecting the user to a page to confirm whether they
want their access token sent to `client_redirect_url`, or redirect them to that
@@ -299,9 +316,15 @@ class ModuleApi:
redirect them directly if whitelisted).
new_user: set to true to use wording for the consent appropriate to a user
who has just registered.
auth_provider_id: the ID of the SSO IdP which was used to log in. This
is used to track counts of sucessful logins by IdP.
"""
await self._auth_handler.complete_sso_login(
registered_user_id, request, client_redirect_url, new_user=new_user
registered_user_id,
auth_provider_id,
request,
client_redirect_url,
new_user=new_user,
)
@defer.inlineCallbacks
+6 -3
View File
@@ -18,7 +18,7 @@ import logging
import re
import urllib
from inspect import signature
from typing import Dict, List, Tuple
from typing import TYPE_CHECKING, Dict, List, Tuple
from prometheus_client import Counter, Gauge
@@ -28,6 +28,9 @@ from synapse.logging.opentracing import inject_active_span_byte_dict, trace
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import random_string
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
_pending_outgoing_requests = Gauge(
@@ -88,10 +91,10 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
CACHE = True
RETRY_ON_TIMEOUT = True
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
if self.CACHE:
self.response_cache = ResponseCache(
hs, "repl." + self.NAME, timeout_ms=30 * 60 * 1000
hs.get_clock(), "repl." + self.NAME, timeout_ms=30 * 60 * 1000
) # type: ResponseCache[str]
# We reserve `instance_name` as a parameter to sending requests, so we
+1 -1
View File
@@ -328,6 +328,6 @@ def lazyConnection(
factory.continueTrying = reconnect
reactor = hs.get_reactor()
reactor.connectTCP(host, port, factory, 30)
reactor.connectTCP(host, port, factory, timeout=30, bindAddress=None)
return factory.handler
+9 -6
View File
@@ -12,13 +12,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Tuple
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
class PurgeRoomServlet(RestServlet):
@@ -36,16 +43,12 @@ class PurgeRoomServlet(RestServlet):
PATTERNS = admin_patterns("/purge_room$")
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.pagination_handler = hs.get_pagination_handler()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
+14 -9
View File
@@ -12,17 +12,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.http.server import HttpServer
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import UserID
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
class SendServerNoticeServlet(RestServlet):
@@ -44,17 +51,13 @@ class SendServerNoticeServlet(RestServlet):
}
"""
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer): server
"""
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs)
self.snm = hs.get_server_notices_manager()
def register(self, json_resource):
def register(self, json_resource: HttpServer):
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
@@ -66,7 +69,9 @@ class SendServerNoticeServlet(RestServlet):
self.__class__.__name__,
)
async def on_POST(self, request, txn_id=None):
async def on_POST(
self, request: SynapseRequest, txn_id: Optional[str] = None
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ("user_id", "content"))
@@ -90,7 +95,7 @@ class SendServerNoticeServlet(RestServlet):
return 200, {"event_id": event.event_id}
def on_PUT(self, request, txn_id):
def on_PUT(self, request: SynapseRequest, txn_id: str) -> Tuple[int, JsonDict]:
return self.txns.fetch_or_execute_request(
request, self.on_POST, request, txn_id
)
+9 -5
View File
@@ -219,6 +219,7 @@ class LoginRestServlet(RestServlet):
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
ratelimit: bool = True,
auth_provider_id: Optional[str] = None,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -234,6 +235,8 @@ class LoginRestServlet(RestServlet):
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
ratelimit: Whether to ratelimit the login request.
auth_provider_id: The SSO IdP the user used, if any (just used for the
prometheus metrics).
Returns:
result: Dictionary of account information after successful login.
@@ -256,7 +259,7 @@ class LoginRestServlet(RestServlet):
device_id = login_submission.get("device_id")
initial_display_name = login_submission.get("initial_device_display_name")
device_id, access_token = await self.registration_handler.register_device(
user_id, device_id, initial_display_name
user_id, device_id, initial_display_name, auth_provider_id=auth_provider_id
)
result = {
@@ -283,12 +286,13 @@ class LoginRestServlet(RestServlet):
"""
token = login_submission["token"]
auth_handler = self.auth_handler
user_id = await auth_handler.validate_short_term_login_token_and_get_user_id(
token
)
res = await auth_handler.validate_short_term_login_token(token)
return await self._complete_login(
user_id, login_submission, self.auth_handler._sso_login_callback
res.user_id,
login_submission,
self.auth_handler._sso_login_callback,
auth_provider_id=res.auth_provider_id,
)
async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]:
+2 -3
View File
@@ -36,7 +36,6 @@ from typing import (
cast,
)
import twisted.internet.base
import twisted.internet.tcp
from twisted.internet import defer
from twisted.mail.smtp import sendmail
@@ -130,7 +129,7 @@ from synapse.server_notices.worker_server_notices_sender import (
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import Databases, DataStore, Storage
from synapse.streams.events import EventSources
from synapse.types import DomainSpecificString
from synapse.types import DomainSpecificString, ISynapseReactor
from synapse.util import Clock
from synapse.util.distributor import Distributor
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -291,7 +290,7 @@ class HomeServer(metaclass=abc.ABCMeta):
for i in self.REQUIRED_ON_BACKGROUND_TASK_STARTUP:
getattr(self, "get_" + i + "_handler")()
def get_reactor(self) -> twisted.internet.base.ReactorBase:
def get_reactor(self) -> ISynapseReactor:
"""
Fetch the Twisted reactor in use by this HomeServer.
"""
+16
View File
@@ -36,6 +36,14 @@ import attr
from signedjson.key import decode_verify_key_bytes
from six.moves import filter
from unpaddedbase64 import decode_base64
from zope.interface import Interface
from twisted.internet.interfaces import (
IReactorCore,
IReactorPluggableNameResolver,
IReactorTCP,
IReactorTime,
)
from synapse.api.errors import Codes, SynapseError
from synapse.util.stringutils import parse_and_validate_server_name
@@ -68,6 +76,14 @@ MutableStateMap = MutableMapping[StateKey, T]
JsonDict = Dict[str, Any]
# Note that this seems to require inheriting *directly* from Interface in order
# for mypy-zope to realize it is an interface.
class ISynapseReactor(
IReactorTCP, IReactorPluggableNameResolver, IReactorTime, IReactorCore, Interface
):
"""The interfaces necessary for Synapse to function."""
class Requester(
namedtuple(
"Requester",
+18 -8
View File
@@ -76,11 +76,16 @@ class ObservableDeferred:
def callback(r):
object.__setattr__(self, "_result", (True, r))
while self._observers:
observer = self._observers.pop()
try:
# TODO: Handle errors here.
self._observers.pop().callback(r)
except Exception:
pass
observer.callback(r)
except Exception as e:
logger.exception(
"%r threw an exception on .callback(%r), ignoring...",
observer,
r,
exc_info=e,
)
return r
def errback(f):
@@ -90,11 +95,16 @@ class ObservableDeferred:
# traces when we `await` on one of the observer deferreds.
f.value.__failure__ = f
observer = self._observers.pop()
try:
# TODO: Handle errors here.
self._observers.pop().errback(f)
except Exception:
pass
observer.errback(f)
except Exception as e:
logger.exception(
"%r threw an exception on .errback(%r), ignoring...",
observer,
f,
exc_info=e,
)
if consumeErrors:
return None
+4 -6
View File
@@ -13,17 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
from typing import Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.util import Clock
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -37,11 +35,11 @@ class ResponseCache(Generic[T]):
used rather than trying to compute a new response.
"""
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
def __init__(self, clock: Clock, name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
self.clock = hs.get_clock()
self.clock = clock
self.timeout_sec = timeout_ms / 1000.0
self._name = name
+89
View File
@@ -0,0 +1,89 @@
# -*- coding: utf-8 -*-
# Copyright 2020 Quentin Gliech
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for manipulating macaroons"""
from typing import Callable, Optional
import pymacaroons
from pymacaroons.exceptions import MacaroonVerificationFailedException
def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
"""Extracts a caveat value from a macaroon token.
Checks that there is exactly one caveat of the form "key = <val>" in the macaroon,
and returns the extracted value.
Args:
macaroon: the token
key: the key of the caveat to extract
Returns:
The extracted value
Raises:
MacaroonVerificationFailedException: if there are conflicting values for the
caveat in the macaroon, or if the caveat was not found in the macaroon.
"""
prefix = key + " = "
result = None # type: Optional[str]
for caveat in macaroon.caveats:
if not caveat.caveat_id.startswith(prefix):
continue
val = caveat.caveat_id[len(prefix) :]
if result is None:
# first time we found this caveat: record the value
result = val
elif val != result:
# on subsequent occurrences, raise if the value is different.
raise MacaroonVerificationFailedException(
"Conflicting values for caveat " + key
)
if result is not None:
return result
# If the caveat is not there, we raise a MacaroonVerificationFailedException.
# Note that it is insecure to generate a macaroon without all the caveats you
# might need (because there is nothing stopping people from adding extra caveats),
# so if the caveat isn't there, something odd must be going on.
raise MacaroonVerificationFailedException("No %s caveat in macaroon" % (key,))
def satisfy_expiry(v: pymacaroons.Verifier, get_time_ms: Callable[[], int]) -> None:
"""Make a macaroon verifier which accepts 'time' caveats
Builds a caveat verifier which will accept unexpired 'time' caveats, and adds it to
the given macaroon verifier.
Args:
v: the macaroon verifier
get_time_ms: a callable which will return the timestamp after which the caveat
should be considered expired. Normally the current time.
"""
def verify_expiry_caveat(caveat: str):
time_msec = get_time_ms()
prefix = "time < "
if not caveat.startswith(prefix):
return False
expiry = int(caveat[len(prefix) :])
return time_msec < expiry
v.satisfy_general(verify_expiry_caveat)
+29 -20
View File
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def test_short_term_login_token_gives_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
user_id = self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", "", 5000
)
self.assertEqual("a_user", user_id)
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual("a_user", res.user_id)
self.assertEqual("", res.auth_provider_id)
# when we advance the clock, the token should be rejected
self.reactor.advance(6)
self.get_failure(
self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
self.auth_handler.validate_short_term_login_token(token),
AuthError,
)
def test_short_term_login_token_gives_auth_provider(self):
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", auth_provider_id="my_idp"
)
res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
self.assertEqual("a_user", res.user_id)
self.assertEqual("my_idp", res.auth_provider_id)
def test_short_term_login_token_cannot_replace_user_id(self):
token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
token = self.macaroon_generator.generate_short_term_login_token(
"a_user", "", 5000
)
macaroon = pymacaroons.Macaroon.deserialize(token)
user_id = self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
)
res = self.get_success(
self.auth_handler.validate_short_term_login_token(macaroon.serialize())
)
self.assertEqual("a_user", user_id)
self.assertEqual("a_user", res.user_id)
# add another "user_id" caveat, which might allow us to override the
# user_id.
macaroon.add_first_party_caveat("user_id = b_user")
self.get_failure(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
macaroon.serialize()
),
self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
AuthError,
)
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.large_number_of_users)
)
self.get_failure(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
self.get_failure(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
),
ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
return_value=make_awaitable(self.small_number_of_users)
)
self.get_success(
self.auth_handler.validate_short_term_login_token_and_get_user_id(
self.auth_handler.validate_short_term_login_token(
self._get_macaroon().serialize()
)
)
def _get_macaroon(self):
token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
token = self.macaroon_generator.generate_short_term_login_token(
"user_a", "", 5000
)
return pymacaroons.Macaroon.deserialize(token)
+5 -5
View File
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
def test_map_cas_user_to_existing_user(self):
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=False
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
)
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=False
"@test_user:test", "cas", request, "redirect_uri", None, new_user=False
)
def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
"@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
)
@override_config(
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
"@test_user:test", "cas", request, "redirect_uri", None, new_user=True
)
+16 -20
View File
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Optional
from urllib.parse import parse_qs, urlparse
from mock import ANY, Mock, patch
@@ -23,6 +22,7 @@ import pymacaroons
from synapse.handlers.sso import MappingException
from synapse.server import HomeServer
from synapse.types import UserID
from synapse.util.macaroons import get_value_from_macaroon
from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
@@ -360,15 +360,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(name, b"oidc_session")
macaroon = pymacaroons.Macaroon.deserialize(cookie)
state = self.handler._token_generator._get_value_from_macaroon(
macaroon, "state"
)
nonce = self.handler._token_generator._get_value_from_macaroon(
macaroon, "nonce"
)
redirect = self.handler._token_generator._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
state = get_value_from_macaroon(macaroon, "state")
nonce = get_value_from_macaroon(macaroon, "nonce")
redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
self.assertEqual(params["state"], [state])
self.assertEqual(params["nonce"], [nonce])
@@ -434,7 +428,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None, new_user=True
expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -465,7 +459,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None, new_user=False
expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
)
self.provider._exchange_code.assert_called_once_with(code)
self.provider._parse_id_token.assert_not_called()
@@ -651,6 +645,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_called_once_with(
"@foo:test",
"oidc",
request,
client_redirect_url,
{"phone": "1234567"},
@@ -668,7 +663,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", ANY, ANY, None, new_user=True
"@test_user:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -679,7 +674,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user_2:test", ANY, ANY, None, new_user=True
"@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -716,14 +711,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None, new_user=False
user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None, new_user=False
user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -738,7 +733,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
user.to_string(), ANY, ANY, None, new_user=False
user.to_string(), "oidc", ANY, ANY, None, new_user=False
)
auth_handler.complete_sso_login.reset_mock()
@@ -774,7 +769,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
auth_handler.complete_sso_login.assert_called_once_with(
"@TEST_USER_2:test", ANY, ANY, None, new_user=False
"@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
)
def test_map_userinfo_to_invalid_localpart(self):
@@ -810,7 +805,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", ANY, ANY, None, new_user=True
"@test_user1:test", "oidc", ANY, ANY, None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -866,7 +861,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
state: str,
nonce: str,
client_redirect_url: str,
ui_auth_session_id: Optional[str] = None,
ui_auth_session_id: str = "",
) -> str:
from synapse.handlers.oidc_handler import OidcSessionData
@@ -909,6 +904,7 @@ async def _make_callback_with_userinfo(
idp_id="oidc",
nonce="nonce",
client_redirect_url=client_redirect_url,
ui_auth_session_id="",
),
)
request = _build_callback_request("code", state, session)
+5 -5
View File
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "", None, new_user=False
"@test_user:test", "saml", request, "", None, new_user=False
)
# Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
self.handler._handle_authn_response(request, saml_response, "")
)
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "", None, new_user=False
"@test_user:test", "saml", request, "", None, new_user=False
)
def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# test_user is already taken, so test_user1 gets registered instead.
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user1:test", request, "", None, new_user=True
"@test_user1:test", "saml", request, "", None, new_user=True
)
auth_handler.complete_sso_login.reset_mock()
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
# check that the auth handler got called as expected
auth_handler.complete_sso_login.assert_called_once_with(
"@test_user:test", request, "redirect_uri", None, new_user=True
"@test_user:test", "saml", request, "redirect_uri", None, new_user=True
)
+45 -26
View File
@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.api.errors import NotFoundError
from synapse.api.errors import NotFoundError, SynapseError
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -33,9 +31,12 @@ class PurgeTests(HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.room_id = self.helper.create_room_as(self.user_id)
def test_purge(self):
self.store = hs.get_datastore()
self.storage = self.hs.get_storage()
def test_purge_history(self):
"""
Purging a room will delete everything before the topological point.
Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@@ -43,30 +44,27 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
store = self.hs.get_datastore()
storage = self.hs.get_storage()
# Get the topological token
token = self.get_success(
store.get_topological_token_for_event(last["event_id"])
self.store.get_topological_token_for_event(last["event_id"])
)
token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
self.get_success(
storage.purge_events.purge_history(self.room_id, token_str, True)
self.storage.purge_events.purge_history(self.room_id, token_str, True)
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
self.get_failure(store.get_event(first["event_id"]), NotFoundError)
self.get_failure(store.get_event(second["event_id"]), NotFoundError)
self.get_failure(store.get_event(third["event_id"]), NotFoundError)
self.get_success(store.get_event(last["event_id"]))
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
self.get_failure(self.store.get_event(second["event_id"]), NotFoundError)
self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
self.get_success(self.store.get_event(last["event_id"]))
def test_purge_wont_delete_extrems(self):
def test_purge_history_wont_delete_extrems(self):
"""
Purging a room will delete everything before the topological point.
Purging a room history will delete everything before the topological point.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
@@ -74,22 +72,43 @@ class PurgeTests(HomeserverTestCase):
third = self.helper.send(self.room_id, body="test3")
last = self.helper.send(self.room_id, body="test4")
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
token = self.get_success(
storage.get_topological_token_for_event(last["event_id"])
self.store.get_topological_token_for_event(last["event_id"])
)
event = "t{}-{}".format(token.topological + 1, token.stream + 1)
# Purge everything before this topological token
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump()
f = self.failureResultOf(purge)
f = self.get_failure(
self.storage.purge_events.purge_history(self.room_id, event, True),
SynapseError,
)
self.assertIn("greater than forward", f.value.args[0])
# Try and get the events
self.get_success(storage.get_event(first["event_id"]))
self.get_success(storage.get_event(second["event_id"]))
self.get_success(storage.get_event(third["event_id"]))
self.get_success(storage.get_event(last["event_id"]))
self.get_success(self.store.get_event(first["event_id"]))
self.get_success(self.store.get_event(second["event_id"]))
self.get_success(self.store.get_event(third["event_id"]))
self.get_success(self.store.get_event(last["event_id"]))
def test_purge_room(self):
"""
Purging a room will delete everything about it.
"""
# Send four messages to the room
first = self.helper.send(self.room_id, body="test1")
# Get the current room state.
state_handler = self.hs.get_state_handler()
create_event = self.get_success(
state_handler.get_current_state(self.room_id, "m.room.create", "")
)
self.assertIsNotNone(create_event)
# Purge everything before this topological token
self.get_success(self.storage.purge_events.purge_room(self.room_id))
# The events aren't found.
self.store._invalidate_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+131
View File
@@ -0,0 +1,131 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.caches.response_cache import ResponseCache
from tests.server import get_clock
from tests.unittest import TestCase
class DeferredCacheTestCase(TestCase):
"""
A TestCase class for ResponseCache.
The test-case function naming has some logic to it in it's parts, here's some notes about it:
wait: Denotes tests that have an element of "waiting" before its wrapped result becomes available
(Generally these just use .delayed_return instead of .instant_return in it's wrapped call.)
expire: Denotes tests that test expiry after assured existence.
(These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock)
"""
def setUp(self):
self.reactor, self.clock = get_clock()
def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
return ResponseCache(self.clock, name, timeout_ms=ms)
@staticmethod
async def instant_return(o: str) -> str:
return o
async def delayed_return(self, o: str) -> str:
await self.clock.sleep(1)
return o
def test_cache_hit(self):
cache = self.with_cache("keeping_cache", ms=9001)
expected_result = "howdy"
wrap_d = cache.wrap(0, self.instant_return, expected_result)
self.assertEqual(
expected_result,
self.successResultOf(wrap_d),
"initial wrap result should be the same",
)
self.assertEqual(
expected_result,
self.successResultOf(cache.get(0)),
"cache should have the result",
)
def test_cache_miss(self):
cache = self.with_cache("trashing_cache", ms=0)
expected_result = "howdy"
wrap_d = cache.wrap(0, self.instant_return, expected_result)
self.assertEqual(
expected_result,
self.successResultOf(wrap_d),
"initial wrap result should be the same",
)
self.assertIsNone(cache.get(0), "cache should not have the result now")
def test_cache_expire(self):
cache = self.with_cache("short_cache", ms=1000)
expected_result = "howdy"
wrap_d = cache.wrap(0, self.instant_return, expected_result)
self.assertEqual(expected_result, self.successResultOf(wrap_d))
self.assertEqual(
expected_result,
self.successResultOf(cache.get(0)),
"cache should still have the result",
)
# cache eviction timer is handled
self.reactor.pump((2,))
self.assertIsNone(cache.get(0), "cache should not have the result now")
def test_cache_wait_hit(self):
cache = self.with_cache("neutral_cache")
expected_result = "howdy"
wrap_d = cache.wrap(0, self.delayed_return, expected_result)
self.assertNoResult(wrap_d)
# function wakes up, returns result
self.reactor.pump((2,))
self.assertEqual(expected_result, self.successResultOf(wrap_d))
def test_cache_wait_expire(self):
cache = self.with_cache("medium_cache", ms=3000)
expected_result = "howdy"
wrap_d = cache.wrap(0, self.delayed_return, expected_result)
self.assertNoResult(wrap_d)
# stop at 1 second to callback cache eviction callLater at that time, then another to set time at 2
self.reactor.pump((1, 1))
self.assertEqual(expected_result, self.successResultOf(wrap_d))
self.assertEqual(
expected_result,
self.successResultOf(cache.get(0)),
"cache should still have the result",
)
# (1 + 1 + 2) > 3.0, cache eviction timer is handled
self.reactor.pump((2,))
self.assertIsNone(cache.get(0), "cache should not have the result now")