1
0

Compare commits

...

10 Commits

Author SHA1 Message Date
Erik Johnston
f42fa8b15b Correctly set authenticated_entity when creating ad hoc requesters 2020-10-21 16:53:40 +01:00
Erik Johnston
d8902d4bd5 MAU limits are based off of the *authenticated* user 2020-10-21 16:53:40 +01:00
Erik Johnston
a4b03bf183 Privacy policy applies to authenticated entity 2020-10-21 16:53:40 +01:00
Erik Johnston
8d08cf75f4 Add admin API for logging in as a user
Fixes #6054.
2020-10-21 16:53:40 +01:00
Erik Johnston
9edb5b369e Add concept of authenticated_entity vs target_user 2020-10-21 16:53:40 +01:00
Erik Johnston
c238a54357 Make get_user_by_access_token return a proper type 2020-10-21 16:53:40 +01:00
Erik Johnston
8620b27113 Add typing info to registration 2020-10-21 10:50:15 +01:00
Erik Johnston
f0bcf6f578 Add registration store to mypy 2020-10-21 10:41:12 +01:00
Erik Johnston
03076254e3 Format SQL 2020-10-21 10:32:37 +01:00
Erik Johnston
d9c0b19b30 Don't instansiate Requester directly 2020-10-20 17:08:29 +01:00
39 changed files with 672 additions and 221 deletions

View File

@@ -341,6 +341,41 @@ The following fields are returned in the JSON response body:
- ``total`` - Number of rooms. - ``total`` - Number of rooms.
Login as a user
===============
Get an access token that can be used to authenticate as that user. Useful for
when admins wish to do actions on behalf of a user.
The API is::
PUT /_synapse/admin/v1/users/<user_id>/login
{}
An optional ``valid_until_ms`` field can be specified in the request body as an
integer timestamp that specifies when the token should expire. By default tokens
do not expire.
A response body like the following is returned:
.. code:: json
{
"access_token": "<opaque_access_token_string>"
}
This API does *not* generate a new device for the user, and so will not appear
their ``/devices`` list, and in general the target user should not be able to
tell they have been logged in as.
To expire the token call the standard ``/logout`` API with the token.
Note: The token will expire if the *admin* user calls ``/logout/all`` from any
of their devices, but the token will *not* expire if the target user does the
same.
User devices User devices
============ ============

View File

@@ -55,6 +55,7 @@ files =
synapse/spam_checker_api, synapse/spam_checker_api,
synapse/state, synapse/state,
synapse/storage/databases/main/events.py, synapse/storage/databases/main/events.py,
synapse/storage/databases/main/registration.py,
synapse/storage/databases/main/stream.py, synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py, synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py, synapse/storage/database.py,

View File

@@ -33,6 +33,7 @@ from synapse.api.errors import (
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging import opentracing as opentracing from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID from synapse.types import StateMap, UserID
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@@ -194,6 +195,7 @@ class Auth:
if user_id: if user_id:
request.authenticated_entity = user_id request.authenticated_entity = user_id
opentracing.set_tag("authenticated_entity", user_id) opentracing.set_tag("authenticated_entity", user_id)
opentracing.set_tag("target_user", user_id)
opentracing.set_tag("appservice_id", app_service.id) opentracing.set_tag("appservice_id", app_service.id)
if ip_addr and self._track_appservice_user_ips: if ip_addr and self._track_appservice_user_ips:
@@ -210,26 +212,25 @@ class Auth:
user_info = await self.get_user_by_access_token( user_info = await self.get_user_by_access_token(
access_token, rights, allow_expired=allow_expired access_token, rights, allow_expired=allow_expired
) )
user = user_info["user"] user = UserID.from_string(user_info.user_id)
token_id = user_info["token_id"] token_id = user_info.token_id
is_guest = user_info["is_guest"] is_guest = user_info.is_guest
shadow_banned = user_info["shadow_banned"] shadow_banned = user_info.shadow_banned
# Deny the request if the user account has expired. # Deny the request if the user account has expired.
if self._account_validity.enabled and not allow_expired: if self._account_validity.enabled and not allow_expired:
user_id = user.to_string() if await self.store.is_account_expired(
if await self.store.is_account_expired(user_id, self.clock.time_msec()): user_info.user_id, self.clock.time_msec()
):
raise AuthError( raise AuthError(
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT 403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
) )
# device_id may not be present if get_user_by_access_token has been device_id = user_info.device_id
# stubbed out.
device_id = user_info.get("device_id")
if user and access_token and ip_addr: if access_token and ip_addr:
await self.store.insert_client_ip( await self.store.insert_client_ip(
user_id=user.to_string(), user_id=user_info.token_owner,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
user_agent=user_agent, user_agent=user_agent,
@@ -243,8 +244,10 @@ class Auth:
errcode=Codes.GUEST_ACCESS_FORBIDDEN, errcode=Codes.GUEST_ACCESS_FORBIDDEN,
) )
request.authenticated_entity = user.to_string() request.authenticated_entity = user_info.token_owner
opentracing.set_tag("authenticated_entity", user.to_string()) request.target_user = user_info.user_id
opentracing.set_tag("authenticated_entity", user_info.token_owner)
opentracing.set_tag("target_user", user_info.user_id)
if device_id: if device_id:
opentracing.set_tag("device_id", device_id) opentracing.set_tag("device_id", device_id)
@@ -255,6 +258,7 @@ class Auth:
shadow_banned, shadow_banned,
device_id, device_id,
app_service=app_service, app_service=app_service,
authenticated_entity=user_info.token_owner,
) )
except KeyError: except KeyError:
raise MissingClientTokenError() raise MissingClientTokenError()
@@ -286,7 +290,7 @@ class Auth:
async def get_user_by_access_token( async def get_user_by_access_token(
self, token: str, rights: str = "access", allow_expired: bool = False, self, token: str, rights: str = "access", allow_expired: bool = False,
) -> dict: ) -> TokenLookupResult:
""" Validate access token and get user_id from it """ Validate access token and get user_id from it
Args: Args:
@@ -295,13 +299,7 @@ class Auth:
allow this allow this
allow_expired: If False, raises an InvalidClientTokenError allow_expired: If False, raises an InvalidClientTokenError
if the token is expired if the token is expired
Returns:
dict that includes:
`user` (UserID)
`is_guest` (bool)
`shadow_banned` (bool)
`token_id` (int|None): access token id. May be None if guest
`device_id` (str|None): device corresponding to access token
Raises: Raises:
InvalidClientTokenError if a user by that token exists, but the token is InvalidClientTokenError if a user by that token exists, but the token is
expired expired
@@ -311,9 +309,9 @@ class Auth:
if rights == "access": if rights == "access":
# first look in the database # first look in the database
r = await self._look_up_user_by_access_token(token) r = await self.store.get_user_by_access_token(token)
if r: if r:
valid_until_ms = r["valid_until_ms"] valid_until_ms = r.valid_until_ms
if ( if (
not allow_expired not allow_expired
and valid_until_ms is not None and valid_until_ms is not None
@@ -330,7 +328,6 @@ class Auth:
# otherwise it needs to be a valid macaroon # otherwise it needs to be a valid macaroon
try: try:
user_id, guest = self._parse_and_validate_macaroon(token, rights) user_id, guest = self._parse_and_validate_macaroon(token, rights)
user = UserID.from_string(user_id)
if rights == "access": if rights == "access":
if not guest: if not guest:
@@ -356,23 +353,17 @@ class Auth:
raise InvalidClientTokenError( raise InvalidClientTokenError(
"Guest access token used for regular user" "Guest access token used for regular user"
) )
ret = {
"user": user, ret = TokenLookupResult(
"is_guest": True, user_id=user_id,
"shadow_banned": False, is_guest=True,
"token_id": None,
# all guests get the same device id # all guests get the same device id
"device_id": GUEST_DEVICE_ID, device_id=GUEST_DEVICE_ID,
} )
elif rights == "delete_pusher": elif rights == "delete_pusher":
# We don't store these tokens in the database # We don't store these tokens in the database
ret = {
"user": user, ret = TokenLookupResult(user_id=user_id, is_guest=False)
"is_guest": False,
"shadow_banned": False,
"token_id": None,
"device_id": None,
}
else: else:
raise RuntimeError("Unknown rights setting %s", rights) raise RuntimeError("Unknown rights setting %s", rights)
return ret return ret
@@ -481,24 +472,6 @@ class Auth:
now = self.hs.get_clock().time_msec() now = self.hs.get_clock().time_msec()
return now < expiry return now < expiry
async def _look_up_user_by_access_token(self, token):
ret = await self.store.get_user_by_access_token(token)
if not ret:
return None
# we use ret.get() below because *lots* of unit tests stub out
# get_user_by_access_token in a way where it only returns a couple of
# the fields.
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
"is_guest": False,
"shadow_banned": ret.get("shadow_banned"),
"device_id": ret.get("device_id"),
"valid_until_ms": ret.get("valid_until_ms"),
}
return user_info
def get_appservice_by_req(self, request): def get_appservice_by_req(self, request):
token = self.get_access_token_from_request(request) token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token) service = self.store.get_app_service_by_token(token)

View File

@@ -14,10 +14,12 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Optional
from synapse.api.constants import LimitBlockingTypes, UserTypes from synapse.api.constants import LimitBlockingTypes, UserTypes
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.types import Requester
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -33,24 +35,41 @@ class AuthBlocking:
self._max_mau_value = hs.config.max_mau_value self._max_mau_value = hs.config.max_mau_value
self._limit_usage_by_mau = hs.config.limit_usage_by_mau self._limit_usage_by_mau = hs.config.limit_usage_by_mau
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
self._server_name = hs.hostname
async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None): async def check_auth_blocking(
self,
user_id: Optional[str] = None,
threepid: Optional[dict] = None,
user_type: Optional[str] = None,
requester: Optional[Requester] = None,
):
"""Checks if the user should be rejected for some external reason, """Checks if the user should be rejected for some external reason,
such as monthly active user limiting or global disable flag such as monthly active user limiting or global disable flag
Args: Args:
user_id(str|None): If present, checks for presence against existing user_id: If present, checks for presence against existing
MAU cohort MAU cohort
threepid(dict|None): If present, checks for presence against configured threepid: If present, checks for presence against configured
reserved threepid. Used in cases where the user is trying register reserved threepid. Used in cases where the user is trying register
with a MAU blocked server, normally they would be rejected but their with a MAU blocked server, normally they would be rejected but their
threepid is on the reserved list. user_id and threepid is on the reserved list. user_id and
threepid should never be set at the same time. threepid should never be set at the same time.
user_type(str|None): If present, is used to decide whether to check against user_type: If present, is used to decide whether to check against
certain blocking reasons like MAU. certain blocking reasons like MAU.
requester: If present, and the authenticated entity is a user, checks for
presence against existing MAU cohort.
""" """
if requester:
if requester.authenticated_entity.startswith("@"):
user_id = requester.authenticated_entity
elif requester.authenticated_entity == self._server_name:
# We never block the server from doing actions on behalf of
# users.
return
# Never fail an auth check for the server notices users or support user # Never fail an auth check for the server notices users or support user
# This can be a problem where event creation is prohibited due to blocking # This can be a problem where event creation is prohibited due to blocking

View File

@@ -169,7 +169,9 @@ class BaseHandler:
# and having homeservers have their own users leave keeps more # and having homeservers have their own users leave keeps more
# of that decision-making and control local to the guest-having # of that decision-making and control local to the guest-having
# homeserver. # homeserver.
requester = synapse.types.create_requester(target_user, is_guest=True) requester = synapse.types.create_requester(
target_user, is_guest=True, authenticated_entity=self.server_name
)
handler = self.hs.get_room_member_handler() handler = self.hs.get_room_member_handler()
await handler.update_membership( await handler.update_membership(
requester, requester,

View File

@@ -686,8 +686,12 @@ class AuthHandler(BaseHandler):
} }
async def get_access_token_for_user_id( async def get_access_token_for_user_id(
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] self,
): user_id: str,
device_id: Optional[str],
valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
) -> str:
""" """
Creates a new access token for the user with the given user ID. Creates a new access token for the user with the given user ID.
@@ -713,13 +717,25 @@ class AuthHandler(BaseHandler):
fmt_expiry = time.strftime( fmt_expiry = time.strftime(
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0) " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
) )
logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
if puppets_user_id:
logger.info(
"Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
)
else:
logger.info(
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
)
await self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(user_id)
access_token = self.macaroon_gen.generate_access_token(user_id) access_token = self.macaroon_gen.generate_access_token(user_id)
await self.store.add_access_token_to_user( await self.store.add_access_token_to_user(
user_id, access_token, device_id, valid_until_ms user_id=user_id,
token=access_token,
device_id=device_id,
valid_until_ms=valid_until_ms,
puppets_user_id=puppets_user_id,
) )
# the device *should* have been registered before we got here; however, # the device *should* have been registered before we got here; however,
@@ -984,17 +1000,17 @@ class AuthHandler(BaseHandler):
# This might return an awaitable, if it does block the log out # This might return an awaitable, if it does block the log out
# until it completes. # until it completes.
result = provider.on_logged_out( result = provider.on_logged_out(
user_id=str(user_info["user"]), user_id=user_info.user_id,
device_id=user_info["device_id"], device_id=user_info.device_id,
access_token=access_token, access_token=access_token,
) )
if inspect.isawaitable(result): if inspect.isawaitable(result):
await result await result
# delete pushers associated with this access token # delete pushers associated with this access token
if user_info["token_id"] is not None: if user_info.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token( await self.hs.get_pusherpool().remove_pushers_by_access_token(
str(user_info["user"]), (user_info["token_id"],) user_info.user_id, (user_info.token_id,)
) )
async def delete_access_tokens_for_user( async def delete_access_tokens_for_user(

View File

@@ -39,6 +39,7 @@ class DeactivateAccountHandler(BaseHandler):
self._room_member_handler = hs.get_room_member_handler() self._room_member_handler = hs.get_room_member_handler()
self._identity_handler = hs.get_identity_handler() self._identity_handler = hs.get_identity_handler()
self.user_directory_handler = hs.get_user_directory_handler() self.user_directory_handler = hs.get_user_directory_handler()
self._server_name = hs.hostname
# Flag that indicates whether the process to part users from rooms is running # Flag that indicates whether the process to part users from rooms is running
self._user_parter_running = False self._user_parter_running = False
@@ -152,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler):
for room in pending_invites: for room in pending_invites:
try: try:
await self._room_member_handler.update_membership( await self._room_member_handler.update_membership(
create_requester(user), create_requester(user, authenticated_entity=self._server_name),
user, user,
room.room_id, room.room_id,
"leave", "leave",
@@ -208,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler):
logger.info("User parter parting %r from %r", user_id, room_id) logger.info("User parter parting %r from %r", user_id, room_id)
try: try:
await self._room_member_handler.update_membership( await self._room_member_handler.update_membership(
create_requester(user), create_requester(user, authenticated_entity=self._server_name),
user, user,
room_id, room_id,
"leave", "leave",

View File

@@ -473,7 +473,7 @@ class EventCreationHandler:
Returns: Returns:
Tuple of created event, Context Tuple of created event, Context
""" """
await self.auth.check_auth_blocking(requester.user.to_string()) await self.auth.check_auth_blocking(requester=requester)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "": if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version = event_dict["content"]["room_version"] room_version = event_dict["content"]["room_version"]
@@ -620,7 +620,13 @@ class EventCreationHandler:
if requester.app_service is not None: if requester.app_service is not None:
return return
user_id = requester.user.to_string() user_id = requester.authenticated_entity
if not user_id.startswith("@"):
# The authenticated entity might not be a user, e.g. if it's the
# server puppetting the user.
return
user = UserID.from_string(user_id)
# exempt the system notices user # exempt the system notices user
if ( if (
@@ -640,9 +646,7 @@ class EventCreationHandler:
if u["consent_version"] == self.config.user_consent_version: if u["consent_version"] == self.config.user_consent_version:
return return
consent_uri = self._consent_uri_builder.build_user_consent_uri( consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
requester.user.localpart
)
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri} msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri) raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
@@ -1271,7 +1275,9 @@ class EventCreationHandler:
for user_id in members: for user_id in members:
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
continue continue
requester = create_requester(user_id) requester = create_requester(
user_id, authenticated_entity=self._server_name
)
try: try:
event, context = await self.create_event( event, context = await self.create_event(
requester, requester,

View File

@@ -183,7 +183,9 @@ class ProfileHandler(BaseHandler):
# the join event to update the displayname in the rooms. # the join event to update the displayname in the rooms.
# This must be done by the target user himself. # This must be done by the target user himself.
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity,
)
await self.store.set_profile_displayname(target_user.localpart, new_displayname) await self.store.set_profile_displayname(target_user.localpart, new_displayname)
@@ -255,7 +257,9 @@ class ProfileHandler(BaseHandler):
# Same like set_displayname # Same like set_displayname
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url) await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)

View File

@@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler):
self.ratelimiter = hs.get_registration_ratelimiter() self.ratelimiter = hs.get_registration_ratelimiter()
self.macaroon_gen = hs.get_macaroon_generator() self.macaroon_gen = hs.get_macaroon_generator()
self._server_notices_mxid = hs.config.server_notices_mxid self._server_notices_mxid = hs.config.server_notices_mxid
self._server_name = hs.hostname
self.spam_checker = hs.get_spam_checker() self.spam_checker = hs.get_spam_checker()
@@ -115,7 +116,10 @@ class RegistrationHandler(BaseHandler):
400, "User ID already taken.", errcode=Codes.USER_IN_USE 400, "User ID already taken.", errcode=Codes.USER_IN_USE
) )
user_data = await self.auth.get_user_by_access_token(guest_access_token) user_data = await self.auth.get_user_by_access_token(guest_access_token)
if not user_data["is_guest"] or user_data["user"].localpart != localpart: if (
not user_data.is_guest
or UserID.from_string(user_data.user_id).localpart != localpart
):
raise AuthError( raise AuthError(
403, 403,
"Cannot register taken user ID without valid guest " "Cannot register taken user ID without valid guest "
@@ -314,7 +318,8 @@ class RegistrationHandler(BaseHandler):
requires_join = False requires_join = False
if self.hs.config.registration.auto_join_user_id: if self.hs.config.registration.auto_join_user_id:
fake_requester = create_requester( fake_requester = create_requester(
self.hs.config.registration.auto_join_user_id self.hs.config.registration.auto_join_user_id,
authenticated_entity=self._server_name,
) )
# If the room requires an invite, add the user to the list of invites. # If the room requires an invite, add the user to the list of invites.
@@ -326,7 +331,9 @@ class RegistrationHandler(BaseHandler):
# being necessary this will occur after the invite was sent. # being necessary this will occur after the invite was sent.
requires_join = True requires_join = True
else: else:
fake_requester = create_requester(user_id) fake_requester = create_requester(
user_id, authenticated_entity=self._server_name
)
# Choose whether to federate the new room. # Choose whether to federate the new room.
if not self.hs.config.registration.autocreate_auto_join_rooms_federated: if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
@@ -359,7 +366,9 @@ class RegistrationHandler(BaseHandler):
# created it, then ensure the first user joins it. # created it, then ensure the first user joins it.
if requires_join: if requires_join:
await room_member_handler.update_membership( await room_member_handler.update_membership(
requester=create_requester(user_id), requester=create_requester(
user_id, authenticated_entity=self._server_name
),
target=UserID.from_string(user_id), target=UserID.from_string(user_id),
room_id=info["room_id"], room_id=info["room_id"],
# Since it was just created, there are no remote hosts. # Since it was just created, there are no remote hosts.
@@ -423,7 +432,8 @@ class RegistrationHandler(BaseHandler):
if requires_invite: if requires_invite:
await room_member_handler.update_membership( await room_member_handler.update_membership(
requester=create_requester( requester=create_requester(
self.hs.config.registration.auto_join_user_id self.hs.config.registration.auto_join_user_id,
authenticated_entity=self._server_name,
), ),
target=UserID.from_string(user_id), target=UserID.from_string(user_id),
room_id=room_id, room_id=room_id,
@@ -434,7 +444,9 @@ class RegistrationHandler(BaseHandler):
# Send the join. # Send the join.
await room_member_handler.update_membership( await room_member_handler.update_membership(
requester=create_requester(user_id), requester=create_requester(
user_id, authenticated_entity=self._server_name
),
target=UserID.from_string(user_id), target=UserID.from_string(user_id),
room_id=room_id, room_id=room_id,
remote_room_hosts=remote_room_hosts, remote_room_hosts=remote_room_hosts,
@@ -741,7 +753,7 @@ class RegistrationHandler(BaseHandler):
# up when the access token is saved, but that's quite an # up when the access token is saved, but that's quite an
# invasive change I'd rather do separately. # invasive change I'd rather do separately.
user_tuple = await self.store.get_user_by_access_token(token) user_tuple = await self.store.get_user_by_access_token(token)
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
await self.pusher_pool.add_pusher( await self.pusher_pool.add_pusher(
user_id=user_id, user_id=user_id,

View File

@@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler):
""" """
user_id = requester.user.to_string() user_id = requester.user.to_string()
await self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(requester=requester)
if ( if (
self._server_notices_mxid is not None self._server_notices_mxid is not None
@@ -1250,7 +1250,9 @@ class RoomShutdownHandler:
400, "User must be our own: %s" % (new_room_user_id,) 400, "User must be our own: %s" % (new_room_user_id,)
) )
room_creator_requester = create_requester(new_room_user_id) room_creator_requester = create_requester(
new_room_user_id, authenticated_entity=requester_user_id
)
info, stream_id = await self._room_creation_handler.create_room( info, stream_id = await self._room_creation_handler.create_room(
room_creator_requester, room_creator_requester,
@@ -1290,7 +1292,9 @@ class RoomShutdownHandler:
try: try:
# Kick users from room # Kick users from room
target_requester = create_requester(user_id) target_requester = create_requester(
user_id, authenticated_entity=requester_user_id
)
_, stream_id = await self.room_member_handler.update_membership( _, stream_id = await self.room_member_handler.update_membership(
requester=target_requester, requester=target_requester,
target=target_requester.user, target=target_requester.user,

View File

@@ -961,6 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
self.distributor = hs.get_distributor() self.distributor = hs.get_distributor()
self.distributor.declare("user_left_room") self.distributor.declare("user_left_room")
self._server_name = hs.hostname
async def _is_remote_room_too_complex( async def _is_remote_room_too_complex(
self, room_id: str, remote_room_hosts: List[str] self, room_id: str, remote_room_hosts: List[str]
@@ -1055,7 +1056,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
return event_id, stream_id return event_id, stream_id
# The room is too large. Leave. # The room is too large. Leave.
requester = types.create_requester(user, None, False, False, None) requester = types.create_requester(
user, authenticated_entity=self._server_name
)
await self.update_membership( await self.update_membership(
requester=requester, target=user, room_id=room_id, action="leave" requester=requester, target=user, room_id=room_id, action="leave"
) )

View File

@@ -31,6 +31,7 @@ from synapse.types import (
Collection, Collection,
JsonDict, JsonDict,
MutableStateMap, MutableStateMap,
Requester,
RoomStreamToken, RoomStreamToken,
StateMap, StateMap,
StreamToken, StreamToken,
@@ -260,6 +261,7 @@ class SyncHandler:
async def wait_for_sync_for_user( async def wait_for_sync_for_user(
self, self,
requester: Requester,
sync_config: SyncConfig, sync_config: SyncConfig,
since_token: Optional[StreamToken] = None, since_token: Optional[StreamToken] = None,
timeout: int = 0, timeout: int = 0,
@@ -273,7 +275,7 @@ class SyncHandler:
# not been exceeded (if not part of the group by this point, almost certain # not been exceeded (if not part of the group by this point, almost certain
# auth_blocking will occur) # auth_blocking will occur)
user_id = sync_config.user.to_string() user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(user_id) await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap( res = await self.response_cache.wrap(
sync_config.request_key, sync_config.request_key,

View File

@@ -55,6 +55,7 @@ class SynapseRequest(Request):
self.site = channel.site self.site = channel.site
self._channel = channel # this is used by the tests self._channel = channel # this is used by the tests
self.authenticated_entity = None self.authenticated_entity = None
self.target_user = None
self.start_time = 0.0 self.start_time = 0.0
# we can't yet create the logcontext, as we don't know the method. # we can't yet create the logcontext, as we don't know the method.
@@ -269,6 +270,11 @@ class SynapseRequest(Request):
if authenticated_entity is not None and isinstance(authenticated_entity, bytes): if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
authenticated_entity = authenticated_entity.decode("utf-8", "replace") authenticated_entity = authenticated_entity.decode("utf-8", "replace")
if self.target_user:
authenticated_entity = "{} as {}".format(
authenticated_entity, self.target_user,
)
# ...or could be raw utf-8 bytes in the User-Agent header. # ...or could be raw utf-8 bytes in the User-Agent header.
# N.B. if you don't do this, the logger explodes cryptically # N.B. if you don't do this, the logger explodes cryptically
# with maximum recursion trying to log errors about # with maximum recursion trying to log errors about

View File

@@ -49,6 +49,7 @@ class ModuleApi:
self._store = hs.get_datastore() self._store = hs.get_datastore()
self._auth = hs.get_auth() self._auth = hs.get_auth()
self._auth_handler = auth_handler self._auth_handler = auth_handler
self._server_name = hs.hostname
# We expose these as properties below in order to attach a helpful docstring. # We expose these as properties below in order to attach a helpful docstring.
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
@@ -336,7 +337,9 @@ class ModuleApi:
SynapseError if the event was not allowed. SynapseError if the event was not allowed.
""" """
# Create a requester object # Create a requester object
requester = create_requester(event_dict["sender"]) requester = create_requester(
event_dict["sender"], authenticated_entity=self._server_name
)
# Create and send the event # Create and send the event
( (

View File

@@ -55,6 +55,7 @@ from synapse.rest.admin.users import (
UserRestServletV2, UserRestServletV2,
UsersRestServlet, UsersRestServlet,
UsersRestServletV2, UsersRestServletV2,
UserTokenRestServlet,
WhoisRestServlet, WhoisRestServlet,
) )
from synapse.types import RoomStreamToken from synapse.types import RoomStreamToken
@@ -216,6 +217,7 @@ def register_servlets(hs, http_server):
VersionServlet(hs).register(http_server) VersionServlet(hs).register(http_server)
UserAdminServlet(hs).register(http_server) UserAdminServlet(hs).register(http_server)
UserMembershipRestServlet(hs).register(http_server) UserMembershipRestServlet(hs).register(http_server)
UserTokenRestServlet(hs).register(http_server)
UserRestServletV2(hs).register(http_server) UserRestServletV2(hs).register(http_server)
UsersRestServletV2(hs).register(http_server) UsersRestServletV2(hs).register(http_server)
DeviceRestServlet(hs).register(http_server) DeviceRestServlet(hs).register(http_server)

View File

@@ -309,7 +309,9 @@ class JoinRoomAliasServlet(RestServlet):
400, "%s was not legal room ID or room alias" % (room_identifier,) 400, "%s was not legal room ID or room alias" % (room_identifier,)
) )
fake_requester = create_requester(target_user) fake_requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
)
# send invite if room has "JoinRules.INVITE" # send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id) room_state = await self.state_handler.get_current_state(room_id)

View File

@@ -16,6 +16,7 @@ import hashlib
import hmac import hmac
import logging import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -35,6 +36,9 @@ from synapse.rest.admin._base import (
) )
from synapse.types import UserID from synapse.types import UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -708,3 +712,42 @@ class UserMembershipRestServlet(RestServlet):
ret = {"joined_rooms": list(room_ids), "total": len(room_ids)} ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
return 200, ret return 200, ret
class UserTokenRestServlet(RestServlet):
"""An admin API for logging in as a user.
"""
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
async def on_PUT(self, request, user_id):
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
if not self.hs.is_mine_id(user_id):
raise SynapseError(400, "Only local users can be logged in as")
body = parse_json_object_from_request(request)
valid_until_ms = body.get("valid_until_ms")
if valid_until_ms and not isinstance(valid_until_ms, int):
raise SynapseError(400, "'valid_until_ms' parameter must be an int")
if auth_user.to_string() == user_id:
raise SynapseError(400, "Cannot use admin API to login as self")
token = await self.auth_handler.get_access_token_for_user_id(
user_id=auth_user.to_string(),
device_id=None,
valid_until_ms=valid_until_ms,
puppets_user_id=user_id,
)
return 200, {"access_token": token}

View File

@@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
) )
with context: with context:
sync_result = await self.sync_handler.wait_for_sync_for_user( sync_result = await self.sync_handler.wait_for_sync_for_user(
requester,
sync_config, sync_config,
since_token=since_token, since_token=since_token,
timeout=timeout, timeout=timeout,

View File

@@ -39,6 +39,7 @@ class ServerNoticesManager:
self._room_member_handler = hs.get_room_member_handler() self._room_member_handler = hs.get_room_member_handler()
self._event_creation_handler = hs.get_event_creation_handler() self._event_creation_handler = hs.get_event_creation_handler()
self._is_mine_id = hs.is_mine_id self._is_mine_id = hs.is_mine_id
self._server_name = hs.hostname
self._notifier = hs.get_notifier() self._notifier = hs.get_notifier()
self.server_notices_mxid = self._config.server_notices_mxid self.server_notices_mxid = self._config.server_notices_mxid
@@ -72,7 +73,9 @@ class ServerNoticesManager:
await self.maybe_invite_user_to_room(user_id, room_id) await self.maybe_invite_user_to_room(user_id, room_id)
system_mxid = self._config.server_notices_mxid system_mxid = self._config.server_notices_mxid
requester = create_requester(system_mxid) requester = create_requester(
system_mxid, authenticated_entity=self._server_name
)
logger.info("Sending server notice to %s", user_id) logger.info("Sending server notice to %s", user_id)
@@ -145,7 +148,9 @@ class ServerNoticesManager:
"avatar_url": self._config.server_notices_mxid_avatar_url, "avatar_url": self._config.server_notices_mxid_avatar_url,
} }
requester = create_requester(self.server_notices_mxid) requester = create_requester(
self.server_notices_mxid, authenticated_entity=self._server_name
)
info, _ = await self._room_creation_handler.create_room( info, _ = await self._room_creation_handler.create_room(
requester, requester,
config={ config={
@@ -174,7 +179,9 @@ class ServerNoticesManager:
user_id: The ID of the user to invite. user_id: The ID of the user to invite.
room_id: The ID of the room to invite the user to. room_id: The ID of the room to invite the user to.
""" """
requester = create_requester(self.server_notices_mxid) requester = create_requester(
self.server_notices_mxid, authenticated_entity=self._server_name
)
# Check whether the user has already joined or been invited to this room. If # Check whether the user has already joined or been invited to this room. If
# that's the case, there is no need to re-invite them. # that's the case, there is no need to re-invite them.

View File

@@ -146,7 +146,6 @@ class DataStore(
db_conn, "e2e_cross_signing_keys", "stream_id" db_conn, "e2e_cross_signing_keys", "stream_id"
) )
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id") self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id") self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id") self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")

View File

@@ -16,29 +16,53 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from typing import Any, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import attr
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Cursor from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import IdGenerator
from synapse.storage.util.sequence import build_sequence_generator from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import UserID from synapse.types import UserID
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.server import HomeServer
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000 THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class RegistrationWorkerStore(SQLBaseStore): @attr.s(frozen=True, slots=True)
def __init__(self, database: DatabasePool, db_conn, hs): class TokenLookupResult:
"""Result of looking up an access token.
"""
user_id = attr.ib(type=str)
is_guest = attr.ib(type=bool, default=False)
shadow_banned = attr.ib(type=bool, default=False)
token_id = attr.ib(type=Optional[int], default=None)
device_id = attr.ib(type=Optional[str], default=None)
valid_until_ms = attr.ib(type=Optional[int], default=None)
token_owner = attr.ib(type=str)
@token_owner.default
def _default_token_owner(self):
return self.user_id
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.config = hs.config self.config = hs.config
self.clock = hs.get_clock()
# Note: we don't check this sequence for consistency as we'd have to # Note: we don't check this sequence for consistency as we'd have to
# call `find_max_generated_user_id_localpart` each time, which is # call `find_max_generated_user_id_localpart` each time, which is
@@ -55,7 +79,7 @@ class RegistrationWorkerStore(SQLBaseStore):
# Create a background job for culling expired 3PID validity tokens # Create a background job for culling expired 3PID validity tokens
if hs.config.run_background_tasks: if hs.config.run_background_tasks:
self.clock.looping_call( self._clock.looping_call(
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
) )
@@ -92,13 +116,13 @@ class RegistrationWorkerStore(SQLBaseStore):
if not info: if not info:
return False return False
now = self.clock.time_msec() now = self._clock.time_msec()
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
return is_trial return is_trial
@cached() @cached()
async def get_user_by_access_token(self, token: str) -> Optional[dict]: async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
"""Get a user from the given access token. """Get a user from the given access token.
Args: Args:
@@ -257,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_users_expiring_soon", "get_users_expiring_soon",
select_users_txn, select_users_txn,
self.clock.time_msec(), self._clock.time_msec(),
self.config.account_validity.renew_at, self.config.account_validity.renew_at,
) )
@@ -327,19 +351,24 @@ class RegistrationWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn) await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
def _query_for_auth(self, txn, token): def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
sql = ( sql = """
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id," SELECT users.name as user_id,
" access_tokens.device_id, access_tokens.valid_until_ms" users.is_guest,
" FROM users" users.shadow_banned,
" INNER JOIN access_tokens on users.name = access_tokens.user_id" access_tokens.id as token_id,
" WHERE token = ?" access_tokens.device_id,
) access_tokens.valid_until_ms,
access_tokens.user_id as token_owner
FROM users
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
WHERE token = ?
"""
txn.execute(sql, (token,)) txn.execute(sql, (token,))
rows = self.db_pool.cursor_to_dict(txn) rows = self.db_pool.cursor_to_dict(txn)
if rows: if rows:
return rows[0] return TokenLookupResult(**rows[0])
return None return None
@@ -803,7 +832,7 @@ class RegistrationWorkerStore(SQLBaseStore):
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
"cull_expired_threepid_validation_tokens", "cull_expired_threepid_validation_tokens",
cull_expired_threepid_validation_tokens_txn, cull_expired_threepid_validation_tokens_txn,
self.clock.time_msec(), self._clock.time_msec(),
) )
@wrap_as_background_process("account_validity_set_expiration_dates") @wrap_as_background_process("account_validity_set_expiration_dates")
@@ -890,10 +919,10 @@ class RegistrationWorkerStore(SQLBaseStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.clock = hs.get_clock() self._clock = hs.get_clock()
self.config = hs.config self.config = hs.config
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
@@ -1016,19 +1045,63 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
return 1 return 1
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
class RegistrationStore(RegistrationBackgroundUpdateStore): Args:
def __init__(self, database: DatabasePool, db_conn, hs): user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""
await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
async def add_access_token_to_user( async def add_access_token_to_user(
self, self,
user_id: str, user_id: str,
token: str, token: str,
device_id: Optional[str], device_id: Optional[str],
valid_until_ms: Optional[int], valid_until_ms: Optional[int],
puppets_user_id: Optional[str] = None,
) -> int: ) -> int:
"""Adds an access token for the given user. """Adds an access token for the given user.
@@ -1052,6 +1125,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"token": token, "token": token,
"device_id": device_id, "device_id": device_id,
"valid_until_ms": valid_until_ms, "valid_until_ms": valid_until_ms,
"puppets_user_id": puppets_user_id,
}, },
desc="add_access_token_to_user", desc="add_access_token_to_user",
) )
@@ -1138,19 +1212,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
def _register_user( def _register_user(
self, self,
txn, txn,
user_id, user_id: str,
password_hash, password_hash: Optional[str],
was_guest, was_guest: bool,
make_guest, make_guest: bool,
appservice_id, appservice_id: Optional[str],
create_profile_with_displayname, create_profile_with_displayname: Optional[str],
admin, admin: bool,
user_type, user_type: Optional[str],
shadow_banned, shadow_banned: bool,
): ):
user_id_obj = UserID.from_string(user_id) user_id_obj = UserID.from_string(user_id)
now = int(self.clock.time()) now = int(self._clock.time())
try: try:
if was_guest: if was_guest:
@@ -1374,18 +1448,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
await self.db_pool.runInteraction("delete_access_token", f) await self.db_pool.runInteraction("delete_access_token", f)
@cached()
async def is_guest(self, user_id: str) -> bool:
res = await self.db_pool.simple_select_one_onecol(
table="users",
keyvalues={"name": user_id},
retcol="is_guest",
allow_none=True,
desc="is_guest",
)
return res if res else False
async def add_user_pending_deactivation(self, user_id: str) -> None: async def add_user_pending_deactivation(self, user_id: str) -> None:
""" """
Adds a user to the table of users who need to be parted from all the rooms they're Adds a user to the table of users who need to be parted from all the rooms they're
@@ -1479,7 +1541,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
txn, txn,
table="threepid_validation_session", table="threepid_validation_session",
keyvalues={"session_id": session_id}, keyvalues={"session_id": session_id},
updatevalues={"validated_at": self.clock.time_msec()}, updatevalues={"validated_at": self._clock.time_msec()},
) )
return next_link return next_link
@@ -1547,35 +1609,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
start_or_continue_validation_session_txn, start_or_continue_validation_session_txn,
) )
async def set_user_deactivated_status(
self, user_id: str, deactivated: bool
) -> None:
"""Set the `deactivated` property for the provided user to the provided value.
Args:
user_id: The ID of the user to set the status for.
deactivated: The value to set for `deactivated`.
"""
await self.db_pool.runInteraction(
"set_user_deactivated_status",
self.set_user_deactivated_status_txn,
user_id,
deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self.db_pool.simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
txn.call_after(self.is_guest.invalidate, (user_id,))
def find_max_generated_user_id_localpart(cur: Cursor) -> int: def find_max_generated_user_id_localpart(cur: Cursor) -> int:
""" """

View File

@@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Whether the access token is an admin token for controlling another user.
ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;

View File

@@ -74,6 +74,7 @@ class Requester(
"shadow_banned", "shadow_banned",
"device_id", "device_id",
"app_service", "app_service",
"authenticated_entity",
], ],
) )
): ):
@@ -104,6 +105,7 @@ class Requester(
"shadow_banned": self.shadow_banned, "shadow_banned": self.shadow_banned,
"device_id": self.device_id, "device_id": self.device_id,
"app_server_id": self.app_service.id if self.app_service else None, "app_server_id": self.app_service.id if self.app_service else None,
"authenticated_entity": self.authenticated_entity,
} }
@staticmethod @staticmethod
@@ -129,6 +131,7 @@ class Requester(
shadow_banned=input["shadow_banned"], shadow_banned=input["shadow_banned"],
device_id=input["device_id"], device_id=input["device_id"],
app_service=appservice, app_service=appservice,
authenticated_entity=input["authenticated_entity"],
) )
@@ -139,6 +142,7 @@ def create_requester(
shadow_banned=False, shadow_banned=False,
device_id=None, device_id=None,
app_service=None, app_service=None,
authenticated_entity=None,
): ):
""" """
Create a new ``Requester`` object Create a new ``Requester`` object
@@ -151,14 +155,27 @@ def create_requester(
shadow_banned (bool): True if the user making this request is shadow-banned. shadow_banned (bool): True if the user making this request is shadow-banned.
device_id (str|None): device_id which was set at authentication time device_id (str|None): device_id which was set at authentication time
app_service (ApplicationService|None): the AS requesting on behalf of the user app_service (ApplicationService|None): the AS requesting on behalf of the user
authenticated_entity: The entity that authenticatd when making the request,
this is different than the user_id when an admin user or the server is
"puppeting" the user.
Returns: Returns:
Requester Requester
""" """
if not isinstance(user_id, UserID): if not isinstance(user_id, UserID):
user_id = UserID.from_string(user_id) user_id = UserID.from_string(user_id)
if authenticated_entity is None:
authenticated_entity = user_id.to_string()
return Requester( return Requester(
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service user_id,
access_token_id,
is_guest,
shadow_banned,
device_id,
app_service,
authenticated_entity,
) )

View File

@@ -29,6 +29,7 @@ from synapse.api.errors import (
MissingClientTokenError, MissingClientTokenError,
ResourceLimitError, ResourceLimitError,
) )
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import UserID from synapse.types import UserID
from tests import unittest from tests import unittest
@@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_user_by_req_user_valid_token(self): def test_get_user_by_req_user_valid_token(self):
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"} user_info = TokenLookupResult(
user_id=self.test_user, token_id=5, device_id="device"
)
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info) return_value=defer.succeed(user_info)
) )
@@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN") self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
def test_get_user_by_req_user_missing_token(self): def test_get_user_by_req_user_missing_token(self):
user_info = {"name": self.test_user, "token_id": "ditto"} user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value=defer.succeed(user_info) return_value=defer.succeed(user_info)
) )
@@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
def test_get_user_from_macaroon(self): def test_get_user_from_macaroon(self):
self.store.get_user_by_access_token = Mock( self.store.get_user_by_access_token = Mock(
return_value=defer.succeed( return_value=defer.succeed(
{"name": "@baldrick:matrix.org", "device_id": "device"} TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
) )
) )
@@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred( user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(macaroon.serialize()) self.auth.get_user_by_access_token(macaroon.serialize())
) )
user = user_info["user"] self.assertEqual(user_id, user_info.user_id)
self.assertEqual(UserID.from_string(user_id), user)
# TODO: device_id should come from the macaroon, but currently comes # TODO: device_id should come from the macaroon, but currently comes
# from the db. # from the db.
self.assertEqual(user_info["device_id"], "device") self.assertEqual(user_info.device_id, "device")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_get_guest_user_from_macaroon(self): def test_get_guest_user_from_macaroon(self):
@@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
user_info = yield defer.ensureDeferred( user_info = yield defer.ensureDeferred(
self.auth.get_user_by_access_token(serialized) self.auth.get_user_by_access_token(serialized)
) )
user = user_info["user"] self.assertEqual(user_id, user_info.user_id)
is_guest = user_info["is_guest"] self.assertTrue(user_info.is_guest)
self.assertEqual(UserID.from_string(user_id), user)
self.assertTrue(is_guest)
self.store.get_user_by_id.assert_called_with(user_id) self.store.get_user_by_id.assert_called_with(user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
if token != tok: if token != tok:
return defer.succeed(None) return defer.succeed(None)
return defer.succeed( return defer.succeed(
{ TokenLookupResult(
"name": USER_ID, user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
"is_guest": False, )
"token_id": 1234,
"device_id": "DEVICE",
}
) )
self.store.get_user_by_access_token = get_user self.store.get_user_by_access_token = get_user

View File

@@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
# make sure that our device ID has changed # make sure that our device ID has changed
user_info = self.get_success(self.auth.get_user_by_access_token(access_token)) user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
self.assertEqual(user_info["device_id"], retrieved_device_id) self.assertEqual(user_info.device_id, retrieved_device_id)
# make sure the device has the display name that was set from the login # make sure the device has the display name that was set from the login
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id)) res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))

View File

@@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.info = self.get_success( self.info = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token,) self.hs.get_datastore().get_user_by_access_token(self.access_token,)
) )
self.token_id = self.info["token_id"] self.token_id = self.info.token_id
self.requester = create_requester(self.user_id, access_token_id=self.token_id) self.requester = create_requester(self.user_id, access_token_id=self.token_id)

View File

@@ -16,7 +16,7 @@
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
from synapse.handlers.sync import SyncConfig from synapse.handlers.sync import SyncConfig
from synapse.types import UserID from synapse.types import UserID, create_requester
import tests.unittest import tests.unittest
import tests.utils import tests.utils
@@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
user_id1 = "@user1:test" user_id1 = "@user1:test"
user_id2 = "@user2:test" user_id2 = "@user2:test"
sync_config = self._generate_sync_config(user_id1) sync_config = self._generate_sync_config(user_id1)
requester = create_requester(user_id1)
self.reactor.advance(100) # So we get not 0 time self.reactor.advance(100) # So we get not 0 time
self.auth_blocking._limit_usage_by_mau = True self.auth_blocking._limit_usage_by_mau = True
@@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Check that the happy case does not throw errors # Check that the happy case does not throw errors
self.get_success(self.store.upsert_monthly_active_user(user_id1)) self.get_success(self.store.upsert_monthly_active_user(user_id1))
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config)) self.get_success(
self.sync_handler.wait_for_sync_for_user(requester, sync_config)
)
# Test that global lock works # Test that global lock works
self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled = True
e = self.get_failure( e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError self.sync_handler.wait_for_sync_for_user(requester, sync_config),
ResourceLimitError,
) )
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
self.auth_blocking._hs_disabled = False self.auth_blocking._hs_disabled = False
sync_config = self._generate_sync_config(user_id2) sync_config = self._generate_sync_config(user_id2)
requester = create_requester(user_id2)
e = self.get_failure( e = self.get_failure(
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError self.sync_handler.wait_for_sync_for_user(requester, sync_config),
ResourceLimitError,
) )
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)

View File

@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(self.access_token) self.hs.get_datastore().get_user_by_access_token(self.access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.pusher = self.get_success( self.pusher = self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(

View File

@@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(
@@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
user_tuple = self.get_success( user_tuple = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_tuple["token_id"] token_id = user_tuple.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(

View File

@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
user_dict = self.get_success( user_dict = self.get_success(
self.hs.get_datastore().get_user_by_access_token(access_token) self.hs.get_datastore().get_user_by_access_token(access_token)
) )
token_id = user_dict["token_id"] token_id = user_dict.token_id
self.get_success( self.get_success(
self.hs.get_pusherpool().add_pusher( self.hs.get_pusherpool().add_pusher(

View File

@@ -23,8 +23,8 @@ from mock import Mock
import synapse.rest.admin import synapse.rest.admin
from synapse.api.constants import UserTypes from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, logout, room
from synapse.rest.client.v2_alpha import sync from synapse.rest.client.v2_alpha import devices, sync
from tests import unittest from tests import unittest
from tests.test_utils import make_awaitable from tests.test_utils import make_awaitable
@@ -1101,3 +1101,244 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(number_rooms, channel.json_body["total"]) self.assertEqual(number_rooms, channel.json_body["total"])
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"])) self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
class UserTokenRestTestCase(unittest.HomeserverTestCase):
"""Test for /_synapse/admin/v1/users/<user>/login
"""
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
sync.register_servlets,
room.register_servlets,
devices.register_servlets,
logout.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.other_user_tok = self.login("user", "pass")
self.url = "/_synapse/admin/v1/users/%s/login" % urllib.parse.quote(
self.other_user
)
def _get_token(self) -> str:
request, channel = self.make_request(
"PUT", self.url, b"{}", access_token=self.admin_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
return channel.json_body["access_token"]
def test_no_auth(self):
"""Try to login as a user without authentication.
"""
request, channel = self.make_request("PUT", self.url, b"{}")
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self):
"""Try to login as a user as a non-admin user.
"""
request, channel = self.make_request(
"PUT", self.url, b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
def test_send_event(self):
"""Test that sending event as a user works.
"""
# Create a room.
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
# Login in as the user
puppet_token = self._get_token()
# Test that sending works, and generates the event as the right user.
resp = self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
event_id = resp["event_id"]
event = self.get_success(self.store.get_event(event_id))
self.assertEqual(event.sender, self.other_user)
def test_devices(self):
"""Tests that logging in as a user doesn't create a new device for them.
"""
# Login in as the user
self._get_token()
# Check that we don't see a new device in our devices list
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# We should only see the one device (from the login in `prepare`)
self.assertEqual(len(channel.json_body["devices"]), 1)
def test_logout(self):
"""Test that calling `/logout` with the token works.
"""
# Login in as the user
puppet_token = self._get_token()
# Test that we can successfully make a request
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout with the puppet token
request, channel = self.make_request(
"POST", "logout", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should no longer work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens should still work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
def test_user_logout_all(self):
"""Tests that the target user calling `/logout/all` does *not* expire
the token.
"""
# Login in as the user
puppet_token = self._get_token()
# Test that we can successfully make a request
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout all with the real user token
request, channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should still work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens shouldn't
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
def test_admin_logout_all(self):
"""Tests that the admin user calling `/logout/all` does expire the
token.
"""
# Login in as the user
puppet_token = self._get_token()
# Test that we can successfully make a request
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Logout all with the admin user token
request, channel = self.make_request(
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# The puppet token should no longer work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=puppet_token
)
self.render(request)
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
# .. but the real user's tokens should still work
request, channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
@unittest.override_config(
{
"public_baseurl": "https://example.org/",
"user_consent": {
"version": "1.0",
"policy_name": "My Cool Privacy Policy",
"template_dir": "/",
"require_at_registration": True,
"block_events_error": "You should accept the policy",
},
"form_secret": "123secret",
}
)
def test_consent(self):
"""Test that sending a message is not subject to the privacy policies.
"""
# Have the admin user accept the terms.
self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
# First, cheekily accept the terms and create a room
self.get_success(self.store.user_set_consent_version(self.other_user, "1.0"))
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
self.helper.send_event(room_id, "com.example.test", tok=self.other_user_tok)
# Now unaccept it and check that we can't send an event
self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
self.helper.send_event(
room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
)
# Login in as the user
puppet_token = self._get_token()
# Sending an event on their behalf should work fine
self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
@override_config(
{"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0}
)
def test_mau_limit(self):
# Create a room as the admin user. This will bump the monthly active users to 1.
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
# Trying to join as the other user should fail.
self.helper.join(
room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
)
# Logging in as the other user and joining a room should work, even
# though they should be denied.
puppet_token = self._get_token()
self.helper.join(room_id, user=self.other_user, tok=puppet_token)

View File

@@ -22,7 +22,7 @@ import synapse.rest.admin
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.storage import prepare_database from synapse.storage import prepare_database
from synapse.types import Requester, UserID from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
# Create a test user and room # Create a test user and room
self.user = UserID("alice", "test") self.user = UserID("alice", "test")
self.requester = Requester(self.user, None, False, False, None, None) self.requester = create_requester(self.user)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"] self.room_id = info["room_id"]
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
# Create a test user and room # Create a test user and room
self.user = UserID.from_string(self.register_user("user1", "password")) self.user = UserID.from_string(self.register_user("user1", "password"))
self.token1 = self.login("user1", "password") self.token1 = self.login("user1", "password")
self.requester = Requester(self.user, None, False, False, None, None) self.requester = create_requester(self.user)
info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
self.room_id = info["room_id"] self.room_id = info["room_id"]
self.event_creator = homeserver.get_event_creation_handler() self.event_creator = homeserver.get_event_creation_handler()

View File

@@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
from synapse.metrics import REGISTRY, generate_latest from synapse.metrics import REGISTRY, generate_latest
from synapse.types import Requester, UserID from synapse.types import UserID, create_requester
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
room_creator = self.hs.get_room_creation_handler() room_creator = self.hs.get_room_creation_handler()
user = UserID("alice", "test") user = UserID("alice", "test")
requester = Requester(user, None, False, False, None, None) requester = create_requester(user)
# Real events, forward extremities # Real events, forward extremities
events = [(3, 2), (6, 2), (4, 6)] events = [(3, 2), (6, 2), (4, 6)]

View File

@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
self.store.get_user_by_access_token(self.tokens[1]) self.store.get_user_by_access_token(self.tokens[1])
) )
self.assertDictContainsSubset( self.assertEqual(result.user_id, self.user_id)
{"name": self.user_id, "device_id": self.device_id}, result self.assertEqual(result.device_id, self.device_id)
) self.assertIsNotNone(result.token_id)
self.assertTrue("token_id" in result)
@defer.inlineCallbacks @defer.inlineCallbacks
def test_user_delete_access_tokens(self): def test_user_delete_access_tokens(self):
@@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
user = yield defer.ensureDeferred( user = yield defer.ensureDeferred(
self.store.get_user_by_access_token(self.tokens[0]) self.store.get_user_by_access_token(self.tokens[0])
) )
self.assertEqual(self.user_id, user["name"]) self.assertEqual(self.user_id, user.user_id)
# now delete the rest # now delete the rest
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id)) yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))

View File

@@ -19,7 +19,7 @@ from unittest.mock import Mock
from synapse.api.constants import Membership from synapse.api.constants import Membership
from synapse.rest.admin import register_servlets_for_client_rest_resource from synapse.rest.admin import register_servlets_for_client_rest_resource
from synapse.rest.client.v1 import login, room from synapse.rest.client.v1 import login, room
from synapse.types import Requester, UserID from synapse.types import UserID, create_requester
from tests import unittest from tests import unittest
from tests.test_utils import event_injection from tests.test_utils import event_injection
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
# Now let's create a room, which will insert a membership # Now let's create a room, which will insert a membership
user = UserID("alice", "test") user = UserID("alice", "test")
requester = Requester(user, None, False, False, None, None) requester = create_requester(user)
self.get_success(self.room_creator.create_room(requester, {})) self.get_success(self.room_creator.create_room(requester, {}))
# Register the background update to run again. # Register the background update to run again.

View File

@@ -20,7 +20,7 @@ from twisted.internet.defer import succeed
from synapse.api.errors import FederationError from synapse.api.errors import FederationError
from synapse.events import make_event_from_dict from synapse.events import make_event_from_dict
from synapse.logging.context import LoggingContext from synapse.logging.context import LoggingContext
from synapse.types import Requester, UserID from synapse.types import UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
@@ -43,7 +43,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) )
user_id = UserID("us", "test") user_id = UserID("us", "test")
our_user = Requester(user_id, None, False, False, None, None) our_user = create_requester(user_id)
room_creator = self.homeserver.get_room_creation_handler() room_creator = self.homeserver.get_room_creation_handler()
self.room_id = self.get_success( self.room_id = self.get_success(
room_creator.create_room( room_creator.create_room(

View File

@@ -169,6 +169,7 @@ class StateTestCase(unittest.TestCase):
"get_state_handler", "get_state_handler",
"get_clock", "get_clock",
"get_state_resolution_handler", "get_state_resolution_handler",
"hostname",
] ]
) )
hs.config = default_config("tesths", True) hs.config = default_config("tesths", True)

View File

@@ -44,7 +44,7 @@ from synapse.logging.context import (
set_current_context, set_current_context,
) )
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester from synapse.types import UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import ( from tests.server import (
@@ -627,7 +627,7 @@ class HomeserverTestCase(TestCase):
""" """
event_creator = self.hs.get_event_creation_handler() event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets() secrets = self.hs.get_secrets()
requester = Requester(user, None, False, False, None, None) requester = create_requester(user)
event, context = self.get_success( event, context = self.get_success(
event_creator.create_event( event_creator.create_event(