Compare commits
10 Commits
erikj/dock
...
erikj/pupp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f42fa8b15b | ||
|
|
d8902d4bd5 | ||
|
|
a4b03bf183 | ||
|
|
8d08cf75f4 | ||
|
|
9edb5b369e | ||
|
|
c238a54357 | ||
|
|
8620b27113 | ||
|
|
f0bcf6f578 | ||
|
|
03076254e3 | ||
|
|
d9c0b19b30 |
@@ -341,6 +341,41 @@ The following fields are returned in the JSON response body:
|
|||||||
- ``total`` - Number of rooms.
|
- ``total`` - Number of rooms.
|
||||||
|
|
||||||
|
|
||||||
|
Login as a user
|
||||||
|
===============
|
||||||
|
|
||||||
|
Get an access token that can be used to authenticate as that user. Useful for
|
||||||
|
when admins wish to do actions on behalf of a user.
|
||||||
|
|
||||||
|
The API is::
|
||||||
|
|
||||||
|
PUT /_synapse/admin/v1/users/<user_id>/login
|
||||||
|
{}
|
||||||
|
|
||||||
|
An optional ``valid_until_ms`` field can be specified in the request body as an
|
||||||
|
integer timestamp that specifies when the token should expire. By default tokens
|
||||||
|
do not expire.
|
||||||
|
|
||||||
|
A response body like the following is returned:
|
||||||
|
|
||||||
|
.. code:: json
|
||||||
|
|
||||||
|
{
|
||||||
|
"access_token": "<opaque_access_token_string>"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
This API does *not* generate a new device for the user, and so will not appear
|
||||||
|
their ``/devices`` list, and in general the target user should not be able to
|
||||||
|
tell they have been logged in as.
|
||||||
|
|
||||||
|
To expire the token call the standard ``/logout`` API with the token.
|
||||||
|
|
||||||
|
Note: The token will expire if the *admin* user calls ``/logout/all`` from any
|
||||||
|
of their devices, but the token will *not* expire if the target user does the
|
||||||
|
same.
|
||||||
|
|
||||||
|
|
||||||
User devices
|
User devices
|
||||||
============
|
============
|
||||||
|
|
||||||
|
|||||||
1
mypy.ini
1
mypy.ini
@@ -55,6 +55,7 @@ files =
|
|||||||
synapse/spam_checker_api,
|
synapse/spam_checker_api,
|
||||||
synapse/state,
|
synapse/state,
|
||||||
synapse/storage/databases/main/events.py,
|
synapse/storage/databases/main/events.py,
|
||||||
|
synapse/storage/databases/main/registration.py,
|
||||||
synapse/storage/databases/main/stream.py,
|
synapse/storage/databases/main/stream.py,
|
||||||
synapse/storage/databases/main/ui_auth.py,
|
synapse/storage/databases/main/ui_auth.py,
|
||||||
synapse/storage/database.py,
|
synapse/storage/database.py,
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from synapse.api.errors import (
|
|||||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.logging import opentracing as opentracing
|
from synapse.logging import opentracing as opentracing
|
||||||
|
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import StateMap, UserID
|
from synapse.types import StateMap, UserID
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
@@ -194,6 +195,7 @@ class Auth:
|
|||||||
if user_id:
|
if user_id:
|
||||||
request.authenticated_entity = user_id
|
request.authenticated_entity = user_id
|
||||||
opentracing.set_tag("authenticated_entity", user_id)
|
opentracing.set_tag("authenticated_entity", user_id)
|
||||||
|
opentracing.set_tag("target_user", user_id)
|
||||||
opentracing.set_tag("appservice_id", app_service.id)
|
opentracing.set_tag("appservice_id", app_service.id)
|
||||||
|
|
||||||
if ip_addr and self._track_appservice_user_ips:
|
if ip_addr and self._track_appservice_user_ips:
|
||||||
@@ -210,26 +212,25 @@ class Auth:
|
|||||||
user_info = await self.get_user_by_access_token(
|
user_info = await self.get_user_by_access_token(
|
||||||
access_token, rights, allow_expired=allow_expired
|
access_token, rights, allow_expired=allow_expired
|
||||||
)
|
)
|
||||||
user = user_info["user"]
|
user = UserID.from_string(user_info.user_id)
|
||||||
token_id = user_info["token_id"]
|
token_id = user_info.token_id
|
||||||
is_guest = user_info["is_guest"]
|
is_guest = user_info.is_guest
|
||||||
shadow_banned = user_info["shadow_banned"]
|
shadow_banned = user_info.shadow_banned
|
||||||
|
|
||||||
# Deny the request if the user account has expired.
|
# Deny the request if the user account has expired.
|
||||||
if self._account_validity.enabled and not allow_expired:
|
if self._account_validity.enabled and not allow_expired:
|
||||||
user_id = user.to_string()
|
if await self.store.is_account_expired(
|
||||||
if await self.store.is_account_expired(user_id, self.clock.time_msec()):
|
user_info.user_id, self.clock.time_msec()
|
||||||
|
):
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
403, "User account has expired", errcode=Codes.EXPIRED_ACCOUNT
|
||||||
)
|
)
|
||||||
|
|
||||||
# device_id may not be present if get_user_by_access_token has been
|
device_id = user_info.device_id
|
||||||
# stubbed out.
|
|
||||||
device_id = user_info.get("device_id")
|
|
||||||
|
|
||||||
if user and access_token and ip_addr:
|
if access_token and ip_addr:
|
||||||
await self.store.insert_client_ip(
|
await self.store.insert_client_ip(
|
||||||
user_id=user.to_string(),
|
user_id=user_info.token_owner,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
ip=ip_addr,
|
ip=ip_addr,
|
||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
@@ -243,8 +244,10 @@ class Auth:
|
|||||||
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
errcode=Codes.GUEST_ACCESS_FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
request.authenticated_entity = user.to_string()
|
request.authenticated_entity = user_info.token_owner
|
||||||
opentracing.set_tag("authenticated_entity", user.to_string())
|
request.target_user = user_info.user_id
|
||||||
|
opentracing.set_tag("authenticated_entity", user_info.token_owner)
|
||||||
|
opentracing.set_tag("target_user", user_info.user_id)
|
||||||
if device_id:
|
if device_id:
|
||||||
opentracing.set_tag("device_id", device_id)
|
opentracing.set_tag("device_id", device_id)
|
||||||
|
|
||||||
@@ -255,6 +258,7 @@ class Auth:
|
|||||||
shadow_banned,
|
shadow_banned,
|
||||||
device_id,
|
device_id,
|
||||||
app_service=app_service,
|
app_service=app_service,
|
||||||
|
authenticated_entity=user_info.token_owner,
|
||||||
)
|
)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
raise MissingClientTokenError()
|
raise MissingClientTokenError()
|
||||||
@@ -286,7 +290,7 @@ class Auth:
|
|||||||
|
|
||||||
async def get_user_by_access_token(
|
async def get_user_by_access_token(
|
||||||
self, token: str, rights: str = "access", allow_expired: bool = False,
|
self, token: str, rights: str = "access", allow_expired: bool = False,
|
||||||
) -> dict:
|
) -> TokenLookupResult:
|
||||||
""" Validate access token and get user_id from it
|
""" Validate access token and get user_id from it
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -295,13 +299,7 @@ class Auth:
|
|||||||
allow this
|
allow this
|
||||||
allow_expired: If False, raises an InvalidClientTokenError
|
allow_expired: If False, raises an InvalidClientTokenError
|
||||||
if the token is expired
|
if the token is expired
|
||||||
Returns:
|
|
||||||
dict that includes:
|
|
||||||
`user` (UserID)
|
|
||||||
`is_guest` (bool)
|
|
||||||
`shadow_banned` (bool)
|
|
||||||
`token_id` (int|None): access token id. May be None if guest
|
|
||||||
`device_id` (str|None): device corresponding to access token
|
|
||||||
Raises:
|
Raises:
|
||||||
InvalidClientTokenError if a user by that token exists, but the token is
|
InvalidClientTokenError if a user by that token exists, but the token is
|
||||||
expired
|
expired
|
||||||
@@ -311,9 +309,9 @@ class Auth:
|
|||||||
|
|
||||||
if rights == "access":
|
if rights == "access":
|
||||||
# first look in the database
|
# first look in the database
|
||||||
r = await self._look_up_user_by_access_token(token)
|
r = await self.store.get_user_by_access_token(token)
|
||||||
if r:
|
if r:
|
||||||
valid_until_ms = r["valid_until_ms"]
|
valid_until_ms = r.valid_until_ms
|
||||||
if (
|
if (
|
||||||
not allow_expired
|
not allow_expired
|
||||||
and valid_until_ms is not None
|
and valid_until_ms is not None
|
||||||
@@ -330,7 +328,6 @@ class Auth:
|
|||||||
# otherwise it needs to be a valid macaroon
|
# otherwise it needs to be a valid macaroon
|
||||||
try:
|
try:
|
||||||
user_id, guest = self._parse_and_validate_macaroon(token, rights)
|
user_id, guest = self._parse_and_validate_macaroon(token, rights)
|
||||||
user = UserID.from_string(user_id)
|
|
||||||
|
|
||||||
if rights == "access":
|
if rights == "access":
|
||||||
if not guest:
|
if not guest:
|
||||||
@@ -356,23 +353,17 @@ class Auth:
|
|||||||
raise InvalidClientTokenError(
|
raise InvalidClientTokenError(
|
||||||
"Guest access token used for regular user"
|
"Guest access token used for regular user"
|
||||||
)
|
)
|
||||||
ret = {
|
|
||||||
"user": user,
|
ret = TokenLookupResult(
|
||||||
"is_guest": True,
|
user_id=user_id,
|
||||||
"shadow_banned": False,
|
is_guest=True,
|
||||||
"token_id": None,
|
|
||||||
# all guests get the same device id
|
# all guests get the same device id
|
||||||
"device_id": GUEST_DEVICE_ID,
|
device_id=GUEST_DEVICE_ID,
|
||||||
}
|
)
|
||||||
elif rights == "delete_pusher":
|
elif rights == "delete_pusher":
|
||||||
# We don't store these tokens in the database
|
# We don't store these tokens in the database
|
||||||
ret = {
|
|
||||||
"user": user,
|
ret = TokenLookupResult(user_id=user_id, is_guest=False)
|
||||||
"is_guest": False,
|
|
||||||
"shadow_banned": False,
|
|
||||||
"token_id": None,
|
|
||||||
"device_id": None,
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Unknown rights setting %s", rights)
|
raise RuntimeError("Unknown rights setting %s", rights)
|
||||||
return ret
|
return ret
|
||||||
@@ -481,24 +472,6 @@ class Auth:
|
|||||||
now = self.hs.get_clock().time_msec()
|
now = self.hs.get_clock().time_msec()
|
||||||
return now < expiry
|
return now < expiry
|
||||||
|
|
||||||
async def _look_up_user_by_access_token(self, token):
|
|
||||||
ret = await self.store.get_user_by_access_token(token)
|
|
||||||
if not ret:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# we use ret.get() below because *lots* of unit tests stub out
|
|
||||||
# get_user_by_access_token in a way where it only returns a couple of
|
|
||||||
# the fields.
|
|
||||||
user_info = {
|
|
||||||
"user": UserID.from_string(ret.get("name")),
|
|
||||||
"token_id": ret.get("token_id", None),
|
|
||||||
"is_guest": False,
|
|
||||||
"shadow_banned": ret.get("shadow_banned"),
|
|
||||||
"device_id": ret.get("device_id"),
|
|
||||||
"valid_until_ms": ret.get("valid_until_ms"),
|
|
||||||
}
|
|
||||||
return user_info
|
|
||||||
|
|
||||||
def get_appservice_by_req(self, request):
|
def get_appservice_by_req(self, request):
|
||||||
token = self.get_access_token_from_request(request)
|
token = self.get_access_token_from_request(request)
|
||||||
service = self.store.get_app_service_by_token(token)
|
service = self.store.get_app_service_by_token(token)
|
||||||
|
|||||||
@@ -14,10 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from synapse.api.constants import LimitBlockingTypes, UserTypes
|
from synapse.api.constants import LimitBlockingTypes, UserTypes
|
||||||
from synapse.api.errors import Codes, ResourceLimitError
|
from synapse.api.errors import Codes, ResourceLimitError
|
||||||
from synapse.config.server import is_threepid_reserved
|
from synapse.config.server import is_threepid_reserved
|
||||||
|
from synapse.types import Requester
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -33,24 +35,41 @@ class AuthBlocking:
|
|||||||
self._max_mau_value = hs.config.max_mau_value
|
self._max_mau_value = hs.config.max_mau_value
|
||||||
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
|
self._limit_usage_by_mau = hs.config.limit_usage_by_mau
|
||||||
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
|
self._mau_limits_reserved_threepids = hs.config.mau_limits_reserved_threepids
|
||||||
|
self._server_name = hs.hostname
|
||||||
|
|
||||||
async def check_auth_blocking(self, user_id=None, threepid=None, user_type=None):
|
async def check_auth_blocking(
|
||||||
|
self,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
threepid: Optional[dict] = None,
|
||||||
|
user_type: Optional[str] = None,
|
||||||
|
requester: Optional[Requester] = None,
|
||||||
|
):
|
||||||
"""Checks if the user should be rejected for some external reason,
|
"""Checks if the user should be rejected for some external reason,
|
||||||
such as monthly active user limiting or global disable flag
|
such as monthly active user limiting or global disable flag
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_id(str|None): If present, checks for presence against existing
|
user_id: If present, checks for presence against existing
|
||||||
MAU cohort
|
MAU cohort
|
||||||
|
|
||||||
threepid(dict|None): If present, checks for presence against configured
|
threepid: If present, checks for presence against configured
|
||||||
reserved threepid. Used in cases where the user is trying register
|
reserved threepid. Used in cases where the user is trying register
|
||||||
with a MAU blocked server, normally they would be rejected but their
|
with a MAU blocked server, normally they would be rejected but their
|
||||||
threepid is on the reserved list. user_id and
|
threepid is on the reserved list. user_id and
|
||||||
threepid should never be set at the same time.
|
threepid should never be set at the same time.
|
||||||
|
|
||||||
user_type(str|None): If present, is used to decide whether to check against
|
user_type: If present, is used to decide whether to check against
|
||||||
certain blocking reasons like MAU.
|
certain blocking reasons like MAU.
|
||||||
|
|
||||||
|
requester: If present, and the authenticated entity is a user, checks for
|
||||||
|
presence against existing MAU cohort.
|
||||||
"""
|
"""
|
||||||
|
if requester:
|
||||||
|
if requester.authenticated_entity.startswith("@"):
|
||||||
|
user_id = requester.authenticated_entity
|
||||||
|
elif requester.authenticated_entity == self._server_name:
|
||||||
|
# We never block the server from doing actions on behalf of
|
||||||
|
# users.
|
||||||
|
return
|
||||||
|
|
||||||
# Never fail an auth check for the server notices users or support user
|
# Never fail an auth check for the server notices users or support user
|
||||||
# This can be a problem where event creation is prohibited due to blocking
|
# This can be a problem where event creation is prohibited due to blocking
|
||||||
|
|||||||
@@ -169,7 +169,9 @@ class BaseHandler:
|
|||||||
# and having homeservers have their own users leave keeps more
|
# and having homeservers have their own users leave keeps more
|
||||||
# of that decision-making and control local to the guest-having
|
# of that decision-making and control local to the guest-having
|
||||||
# homeserver.
|
# homeserver.
|
||||||
requester = synapse.types.create_requester(target_user, is_guest=True)
|
requester = synapse.types.create_requester(
|
||||||
|
target_user, is_guest=True, authenticated_entity=self.server_name
|
||||||
|
)
|
||||||
handler = self.hs.get_room_member_handler()
|
handler = self.hs.get_room_member_handler()
|
||||||
await handler.update_membership(
|
await handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
|
|||||||
@@ -686,8 +686,12 @@ class AuthHandler(BaseHandler):
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def get_access_token_for_user_id(
|
async def get_access_token_for_user_id(
|
||||||
self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
|
self,
|
||||||
):
|
user_id: str,
|
||||||
|
device_id: Optional[str],
|
||||||
|
valid_until_ms: Optional[int],
|
||||||
|
puppets_user_id: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Creates a new access token for the user with the given user ID.
|
Creates a new access token for the user with the given user ID.
|
||||||
|
|
||||||
@@ -713,13 +717,25 @@ class AuthHandler(BaseHandler):
|
|||||||
fmt_expiry = time.strftime(
|
fmt_expiry = time.strftime(
|
||||||
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
|
" until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
|
||||||
)
|
)
|
||||||
logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
|
|
||||||
|
if puppets_user_id:
|
||||||
|
logger.info(
|
||||||
|
"Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.info(
|
||||||
|
"Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
|
||||||
|
)
|
||||||
|
|
||||||
await self.auth.check_auth_blocking(user_id)
|
await self.auth.check_auth_blocking(user_id)
|
||||||
|
|
||||||
access_token = self.macaroon_gen.generate_access_token(user_id)
|
access_token = self.macaroon_gen.generate_access_token(user_id)
|
||||||
await self.store.add_access_token_to_user(
|
await self.store.add_access_token_to_user(
|
||||||
user_id, access_token, device_id, valid_until_ms
|
user_id=user_id,
|
||||||
|
token=access_token,
|
||||||
|
device_id=device_id,
|
||||||
|
valid_until_ms=valid_until_ms,
|
||||||
|
puppets_user_id=puppets_user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# the device *should* have been registered before we got here; however,
|
# the device *should* have been registered before we got here; however,
|
||||||
@@ -984,17 +1000,17 @@ class AuthHandler(BaseHandler):
|
|||||||
# This might return an awaitable, if it does block the log out
|
# This might return an awaitable, if it does block the log out
|
||||||
# until it completes.
|
# until it completes.
|
||||||
result = provider.on_logged_out(
|
result = provider.on_logged_out(
|
||||||
user_id=str(user_info["user"]),
|
user_id=user_info.user_id,
|
||||||
device_id=user_info["device_id"],
|
device_id=user_info.device_id,
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
if inspect.isawaitable(result):
|
if inspect.isawaitable(result):
|
||||||
await result
|
await result
|
||||||
|
|
||||||
# delete pushers associated with this access token
|
# delete pushers associated with this access token
|
||||||
if user_info["token_id"] is not None:
|
if user_info.token_id is not None:
|
||||||
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
await self.hs.get_pusherpool().remove_pushers_by_access_token(
|
||||||
str(user_info["user"]), (user_info["token_id"],)
|
user_info.user_id, (user_info.token_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def delete_access_tokens_for_user(
|
async def delete_access_tokens_for_user(
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class DeactivateAccountHandler(BaseHandler):
|
|||||||
self._room_member_handler = hs.get_room_member_handler()
|
self._room_member_handler = hs.get_room_member_handler()
|
||||||
self._identity_handler = hs.get_identity_handler()
|
self._identity_handler = hs.get_identity_handler()
|
||||||
self.user_directory_handler = hs.get_user_directory_handler()
|
self.user_directory_handler = hs.get_user_directory_handler()
|
||||||
|
self._server_name = hs.hostname
|
||||||
|
|
||||||
# Flag that indicates whether the process to part users from rooms is running
|
# Flag that indicates whether the process to part users from rooms is running
|
||||||
self._user_parter_running = False
|
self._user_parter_running = False
|
||||||
@@ -152,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler):
|
|||||||
for room in pending_invites:
|
for room in pending_invites:
|
||||||
try:
|
try:
|
||||||
await self._room_member_handler.update_membership(
|
await self._room_member_handler.update_membership(
|
||||||
create_requester(user),
|
create_requester(user, authenticated_entity=self._server_name),
|
||||||
user,
|
user,
|
||||||
room.room_id,
|
room.room_id,
|
||||||
"leave",
|
"leave",
|
||||||
@@ -208,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler):
|
|||||||
logger.info("User parter parting %r from %r", user_id, room_id)
|
logger.info("User parter parting %r from %r", user_id, room_id)
|
||||||
try:
|
try:
|
||||||
await self._room_member_handler.update_membership(
|
await self._room_member_handler.update_membership(
|
||||||
create_requester(user),
|
create_requester(user, authenticated_entity=self._server_name),
|
||||||
user,
|
user,
|
||||||
room_id,
|
room_id,
|
||||||
"leave",
|
"leave",
|
||||||
|
|||||||
@@ -473,7 +473,7 @@ class EventCreationHandler:
|
|||||||
Returns:
|
Returns:
|
||||||
Tuple of created event, Context
|
Tuple of created event, Context
|
||||||
"""
|
"""
|
||||||
await self.auth.check_auth_blocking(requester.user.to_string())
|
await self.auth.check_auth_blocking(requester=requester)
|
||||||
|
|
||||||
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
|
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
|
||||||
room_version = event_dict["content"]["room_version"]
|
room_version = event_dict["content"]["room_version"]
|
||||||
@@ -620,7 +620,13 @@ class EventCreationHandler:
|
|||||||
if requester.app_service is not None:
|
if requester.app_service is not None:
|
||||||
return
|
return
|
||||||
|
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.authenticated_entity
|
||||||
|
if not user_id.startswith("@"):
|
||||||
|
# The authenticated entity might not be a user, e.g. if it's the
|
||||||
|
# server puppetting the user.
|
||||||
|
return
|
||||||
|
|
||||||
|
user = UserID.from_string(user_id)
|
||||||
|
|
||||||
# exempt the system notices user
|
# exempt the system notices user
|
||||||
if (
|
if (
|
||||||
@@ -640,9 +646,7 @@ class EventCreationHandler:
|
|||||||
if u["consent_version"] == self.config.user_consent_version:
|
if u["consent_version"] == self.config.user_consent_version:
|
||||||
return
|
return
|
||||||
|
|
||||||
consent_uri = self._consent_uri_builder.build_user_consent_uri(
|
consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
|
||||||
requester.user.localpart
|
|
||||||
)
|
|
||||||
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
|
msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
|
||||||
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
|
raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
|
||||||
|
|
||||||
@@ -1271,7 +1275,9 @@ class EventCreationHandler:
|
|||||||
for user_id in members:
|
for user_id in members:
|
||||||
if not self.hs.is_mine_id(user_id):
|
if not self.hs.is_mine_id(user_id):
|
||||||
continue
|
continue
|
||||||
requester = create_requester(user_id)
|
requester = create_requester(
|
||||||
|
user_id, authenticated_entity=self._server_name
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
event, context = await self.create_event(
|
event, context = await self.create_event(
|
||||||
requester,
|
requester,
|
||||||
|
|||||||
@@ -183,7 +183,9 @@ class ProfileHandler(BaseHandler):
|
|||||||
# the join event to update the displayname in the rooms.
|
# the join event to update the displayname in the rooms.
|
||||||
# This must be done by the target user himself.
|
# This must be done by the target user himself.
|
||||||
if by_admin:
|
if by_admin:
|
||||||
requester = create_requester(target_user)
|
requester = create_requester(
|
||||||
|
target_user, authenticated_entity=requester.authenticated_entity,
|
||||||
|
)
|
||||||
|
|
||||||
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
|
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
|
||||||
|
|
||||||
@@ -255,7 +257,9 @@ class ProfileHandler(BaseHandler):
|
|||||||
|
|
||||||
# Same like set_displayname
|
# Same like set_displayname
|
||||||
if by_admin:
|
if by_admin:
|
||||||
requester = create_requester(target_user)
|
requester = create_requester(
|
||||||
|
target_user, authenticated_entity=requester.authenticated_entity
|
||||||
|
)
|
||||||
|
|
||||||
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
|
await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
self.ratelimiter = hs.get_registration_ratelimiter()
|
self.ratelimiter = hs.get_registration_ratelimiter()
|
||||||
self.macaroon_gen = hs.get_macaroon_generator()
|
self.macaroon_gen = hs.get_macaroon_generator()
|
||||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||||
|
self._server_name = hs.hostname
|
||||||
|
|
||||||
self.spam_checker = hs.get_spam_checker()
|
self.spam_checker = hs.get_spam_checker()
|
||||||
|
|
||||||
@@ -115,7 +116,10 @@ class RegistrationHandler(BaseHandler):
|
|||||||
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
400, "User ID already taken.", errcode=Codes.USER_IN_USE
|
||||||
)
|
)
|
||||||
user_data = await self.auth.get_user_by_access_token(guest_access_token)
|
user_data = await self.auth.get_user_by_access_token(guest_access_token)
|
||||||
if not user_data["is_guest"] or user_data["user"].localpart != localpart:
|
if (
|
||||||
|
not user_data.is_guest
|
||||||
|
or UserID.from_string(user_data.user_id).localpart != localpart
|
||||||
|
):
|
||||||
raise AuthError(
|
raise AuthError(
|
||||||
403,
|
403,
|
||||||
"Cannot register taken user ID without valid guest "
|
"Cannot register taken user ID without valid guest "
|
||||||
@@ -314,7 +318,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
requires_join = False
|
requires_join = False
|
||||||
if self.hs.config.registration.auto_join_user_id:
|
if self.hs.config.registration.auto_join_user_id:
|
||||||
fake_requester = create_requester(
|
fake_requester = create_requester(
|
||||||
self.hs.config.registration.auto_join_user_id
|
self.hs.config.registration.auto_join_user_id,
|
||||||
|
authenticated_entity=self._server_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If the room requires an invite, add the user to the list of invites.
|
# If the room requires an invite, add the user to the list of invites.
|
||||||
@@ -326,7 +331,9 @@ class RegistrationHandler(BaseHandler):
|
|||||||
# being necessary this will occur after the invite was sent.
|
# being necessary this will occur after the invite was sent.
|
||||||
requires_join = True
|
requires_join = True
|
||||||
else:
|
else:
|
||||||
fake_requester = create_requester(user_id)
|
fake_requester = create_requester(
|
||||||
|
user_id, authenticated_entity=self._server_name
|
||||||
|
)
|
||||||
|
|
||||||
# Choose whether to federate the new room.
|
# Choose whether to federate the new room.
|
||||||
if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
|
if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
|
||||||
@@ -359,7 +366,9 @@ class RegistrationHandler(BaseHandler):
|
|||||||
# created it, then ensure the first user joins it.
|
# created it, then ensure the first user joins it.
|
||||||
if requires_join:
|
if requires_join:
|
||||||
await room_member_handler.update_membership(
|
await room_member_handler.update_membership(
|
||||||
requester=create_requester(user_id),
|
requester=create_requester(
|
||||||
|
user_id, authenticated_entity=self._server_name
|
||||||
|
),
|
||||||
target=UserID.from_string(user_id),
|
target=UserID.from_string(user_id),
|
||||||
room_id=info["room_id"],
|
room_id=info["room_id"],
|
||||||
# Since it was just created, there are no remote hosts.
|
# Since it was just created, there are no remote hosts.
|
||||||
@@ -423,7 +432,8 @@ class RegistrationHandler(BaseHandler):
|
|||||||
if requires_invite:
|
if requires_invite:
|
||||||
await room_member_handler.update_membership(
|
await room_member_handler.update_membership(
|
||||||
requester=create_requester(
|
requester=create_requester(
|
||||||
self.hs.config.registration.auto_join_user_id
|
self.hs.config.registration.auto_join_user_id,
|
||||||
|
authenticated_entity=self._server_name,
|
||||||
),
|
),
|
||||||
target=UserID.from_string(user_id),
|
target=UserID.from_string(user_id),
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
@@ -434,7 +444,9 @@ class RegistrationHandler(BaseHandler):
|
|||||||
|
|
||||||
# Send the join.
|
# Send the join.
|
||||||
await room_member_handler.update_membership(
|
await room_member_handler.update_membership(
|
||||||
requester=create_requester(user_id),
|
requester=create_requester(
|
||||||
|
user_id, authenticated_entity=self._server_name
|
||||||
|
),
|
||||||
target=UserID.from_string(user_id),
|
target=UserID.from_string(user_id),
|
||||||
room_id=room_id,
|
room_id=room_id,
|
||||||
remote_room_hosts=remote_room_hosts,
|
remote_room_hosts=remote_room_hosts,
|
||||||
@@ -741,7 +753,7 @@ class RegistrationHandler(BaseHandler):
|
|||||||
# up when the access token is saved, but that's quite an
|
# up when the access token is saved, but that's quite an
|
||||||
# invasive change I'd rather do separately.
|
# invasive change I'd rather do separately.
|
||||||
user_tuple = await self.store.get_user_by_access_token(token)
|
user_tuple = await self.store.get_user_by_access_token(token)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
await self.pusher_pool.add_pusher(
|
await self.pusher_pool.add_pusher(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
|||||||
@@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler):
|
|||||||
"""
|
"""
|
||||||
user_id = requester.user.to_string()
|
user_id = requester.user.to_string()
|
||||||
|
|
||||||
await self.auth.check_auth_blocking(user_id)
|
await self.auth.check_auth_blocking(requester=requester)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self._server_notices_mxid is not None
|
self._server_notices_mxid is not None
|
||||||
@@ -1250,7 +1250,9 @@ class RoomShutdownHandler:
|
|||||||
400, "User must be our own: %s" % (new_room_user_id,)
|
400, "User must be our own: %s" % (new_room_user_id,)
|
||||||
)
|
)
|
||||||
|
|
||||||
room_creator_requester = create_requester(new_room_user_id)
|
room_creator_requester = create_requester(
|
||||||
|
new_room_user_id, authenticated_entity=requester_user_id
|
||||||
|
)
|
||||||
|
|
||||||
info, stream_id = await self._room_creation_handler.create_room(
|
info, stream_id = await self._room_creation_handler.create_room(
|
||||||
room_creator_requester,
|
room_creator_requester,
|
||||||
@@ -1290,7 +1292,9 @@ class RoomShutdownHandler:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Kick users from room
|
# Kick users from room
|
||||||
target_requester = create_requester(user_id)
|
target_requester = create_requester(
|
||||||
|
user_id, authenticated_entity=requester_user_id
|
||||||
|
)
|
||||||
_, stream_id = await self.room_member_handler.update_membership(
|
_, stream_id = await self.room_member_handler.update_membership(
|
||||||
requester=target_requester,
|
requester=target_requester,
|
||||||
target=target_requester.user,
|
target=target_requester.user,
|
||||||
|
|||||||
@@ -961,6 +961,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
|
|
||||||
self.distributor = hs.get_distributor()
|
self.distributor = hs.get_distributor()
|
||||||
self.distributor.declare("user_left_room")
|
self.distributor.declare("user_left_room")
|
||||||
|
self._server_name = hs.hostname
|
||||||
|
|
||||||
async def _is_remote_room_too_complex(
|
async def _is_remote_room_too_complex(
|
||||||
self, room_id: str, remote_room_hosts: List[str]
|
self, room_id: str, remote_room_hosts: List[str]
|
||||||
@@ -1055,7 +1056,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
|||||||
return event_id, stream_id
|
return event_id, stream_id
|
||||||
|
|
||||||
# The room is too large. Leave.
|
# The room is too large. Leave.
|
||||||
requester = types.create_requester(user, None, False, False, None)
|
requester = types.create_requester(
|
||||||
|
user, authenticated_entity=self._server_name
|
||||||
|
)
|
||||||
await self.update_membership(
|
await self.update_membership(
|
||||||
requester=requester, target=user, room_id=room_id, action="leave"
|
requester=requester, target=user, room_id=room_id, action="leave"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ from synapse.types import (
|
|||||||
Collection,
|
Collection,
|
||||||
JsonDict,
|
JsonDict,
|
||||||
MutableStateMap,
|
MutableStateMap,
|
||||||
|
Requester,
|
||||||
RoomStreamToken,
|
RoomStreamToken,
|
||||||
StateMap,
|
StateMap,
|
||||||
StreamToken,
|
StreamToken,
|
||||||
@@ -260,6 +261,7 @@ class SyncHandler:
|
|||||||
|
|
||||||
async def wait_for_sync_for_user(
|
async def wait_for_sync_for_user(
|
||||||
self,
|
self,
|
||||||
|
requester: Requester,
|
||||||
sync_config: SyncConfig,
|
sync_config: SyncConfig,
|
||||||
since_token: Optional[StreamToken] = None,
|
since_token: Optional[StreamToken] = None,
|
||||||
timeout: int = 0,
|
timeout: int = 0,
|
||||||
@@ -273,7 +275,7 @@ class SyncHandler:
|
|||||||
# not been exceeded (if not part of the group by this point, almost certain
|
# not been exceeded (if not part of the group by this point, almost certain
|
||||||
# auth_blocking will occur)
|
# auth_blocking will occur)
|
||||||
user_id = sync_config.user.to_string()
|
user_id = sync_config.user.to_string()
|
||||||
await self.auth.check_auth_blocking(user_id)
|
await self.auth.check_auth_blocking(requester=requester)
|
||||||
|
|
||||||
res = await self.response_cache.wrap(
|
res = await self.response_cache.wrap(
|
||||||
sync_config.request_key,
|
sync_config.request_key,
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ class SynapseRequest(Request):
|
|||||||
self.site = channel.site
|
self.site = channel.site
|
||||||
self._channel = channel # this is used by the tests
|
self._channel = channel # this is used by the tests
|
||||||
self.authenticated_entity = None
|
self.authenticated_entity = None
|
||||||
|
self.target_user = None
|
||||||
self.start_time = 0.0
|
self.start_time = 0.0
|
||||||
|
|
||||||
# we can't yet create the logcontext, as we don't know the method.
|
# we can't yet create the logcontext, as we don't know the method.
|
||||||
@@ -269,6 +270,11 @@ class SynapseRequest(Request):
|
|||||||
if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
|
if authenticated_entity is not None and isinstance(authenticated_entity, bytes):
|
||||||
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
|
authenticated_entity = authenticated_entity.decode("utf-8", "replace")
|
||||||
|
|
||||||
|
if self.target_user:
|
||||||
|
authenticated_entity = "{} as {}".format(
|
||||||
|
authenticated_entity, self.target_user,
|
||||||
|
)
|
||||||
|
|
||||||
# ...or could be raw utf-8 bytes in the User-Agent header.
|
# ...or could be raw utf-8 bytes in the User-Agent header.
|
||||||
# N.B. if you don't do this, the logger explodes cryptically
|
# N.B. if you don't do this, the logger explodes cryptically
|
||||||
# with maximum recursion trying to log errors about
|
# with maximum recursion trying to log errors about
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class ModuleApi:
|
|||||||
self._store = hs.get_datastore()
|
self._store = hs.get_datastore()
|
||||||
self._auth = hs.get_auth()
|
self._auth = hs.get_auth()
|
||||||
self._auth_handler = auth_handler
|
self._auth_handler = auth_handler
|
||||||
|
self._server_name = hs.hostname
|
||||||
|
|
||||||
# We expose these as properties below in order to attach a helpful docstring.
|
# We expose these as properties below in order to attach a helpful docstring.
|
||||||
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
|
self._http_client = hs.get_simple_http_client() # type: SimpleHttpClient
|
||||||
@@ -336,7 +337,9 @@ class ModuleApi:
|
|||||||
SynapseError if the event was not allowed.
|
SynapseError if the event was not allowed.
|
||||||
"""
|
"""
|
||||||
# Create a requester object
|
# Create a requester object
|
||||||
requester = create_requester(event_dict["sender"])
|
requester = create_requester(
|
||||||
|
event_dict["sender"], authenticated_entity=self._server_name
|
||||||
|
)
|
||||||
|
|
||||||
# Create and send the event
|
# Create and send the event
|
||||||
(
|
(
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ from synapse.rest.admin.users import (
|
|||||||
UserRestServletV2,
|
UserRestServletV2,
|
||||||
UsersRestServlet,
|
UsersRestServlet,
|
||||||
UsersRestServletV2,
|
UsersRestServletV2,
|
||||||
|
UserTokenRestServlet,
|
||||||
WhoisRestServlet,
|
WhoisRestServlet,
|
||||||
)
|
)
|
||||||
from synapse.types import RoomStreamToken
|
from synapse.types import RoomStreamToken
|
||||||
@@ -216,6 +217,7 @@ def register_servlets(hs, http_server):
|
|||||||
VersionServlet(hs).register(http_server)
|
VersionServlet(hs).register(http_server)
|
||||||
UserAdminServlet(hs).register(http_server)
|
UserAdminServlet(hs).register(http_server)
|
||||||
UserMembershipRestServlet(hs).register(http_server)
|
UserMembershipRestServlet(hs).register(http_server)
|
||||||
|
UserTokenRestServlet(hs).register(http_server)
|
||||||
UserRestServletV2(hs).register(http_server)
|
UserRestServletV2(hs).register(http_server)
|
||||||
UsersRestServletV2(hs).register(http_server)
|
UsersRestServletV2(hs).register(http_server)
|
||||||
DeviceRestServlet(hs).register(http_server)
|
DeviceRestServlet(hs).register(http_server)
|
||||||
|
|||||||
@@ -309,7 +309,9 @@ class JoinRoomAliasServlet(RestServlet):
|
|||||||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||||
)
|
)
|
||||||
|
|
||||||
fake_requester = create_requester(target_user)
|
fake_requester = create_requester(
|
||||||
|
target_user, authenticated_entity=requester.authenticated_entity
|
||||||
|
)
|
||||||
|
|
||||||
# send invite if room has "JoinRules.INVITE"
|
# send invite if room has "JoinRules.INVITE"
|
||||||
room_state = await self.state_handler.get_current_state(room_id)
|
room_state = await self.state_handler.get_current_state(room_id)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import hashlib
|
|||||||
import hmac
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
from synapse.api.errors import Codes, NotFoundError, SynapseError
|
||||||
@@ -35,6 +36,9 @@ from synapse.rest.admin._base import (
|
|||||||
)
|
)
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -708,3 +712,42 @@ class UserMembershipRestServlet(RestServlet):
|
|||||||
|
|
||||||
ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
|
ret = {"joined_rooms": list(room_ids), "total": len(room_ids)}
|
||||||
return 200, ret
|
return 200, ret
|
||||||
|
|
||||||
|
|
||||||
|
class UserTokenRestServlet(RestServlet):
|
||||||
|
"""An admin API for logging in as a user.
|
||||||
|
"""
|
||||||
|
|
||||||
|
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/login$")
|
||||||
|
|
||||||
|
def __init__(self, hs: "HomeServer"):
|
||||||
|
self.hs = hs
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
self.auth = hs.get_auth()
|
||||||
|
self.auth_handler = hs.get_auth_handler()
|
||||||
|
|
||||||
|
async def on_PUT(self, request, user_id):
|
||||||
|
requester = await self.auth.get_user_by_req(request)
|
||||||
|
await assert_user_is_admin(self.auth, requester.user)
|
||||||
|
auth_user = requester.user
|
||||||
|
|
||||||
|
if not self.hs.is_mine_id(user_id):
|
||||||
|
raise SynapseError(400, "Only local users can be logged in as")
|
||||||
|
|
||||||
|
body = parse_json_object_from_request(request)
|
||||||
|
|
||||||
|
valid_until_ms = body.get("valid_until_ms")
|
||||||
|
if valid_until_ms and not isinstance(valid_until_ms, int):
|
||||||
|
raise SynapseError(400, "'valid_until_ms' parameter must be an int")
|
||||||
|
|
||||||
|
if auth_user.to_string() == user_id:
|
||||||
|
raise SynapseError(400, "Cannot use admin API to login as self")
|
||||||
|
|
||||||
|
token = await self.auth_handler.get_access_token_for_user_id(
|
||||||
|
user_id=auth_user.to_string(),
|
||||||
|
device_id=None,
|
||||||
|
valid_until_ms=valid_until_ms,
|
||||||
|
puppets_user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return 200, {"access_token": token}
|
||||||
|
|||||||
@@ -171,6 +171,7 @@ class SyncRestServlet(RestServlet):
|
|||||||
)
|
)
|
||||||
with context:
|
with context:
|
||||||
sync_result = await self.sync_handler.wait_for_sync_for_user(
|
sync_result = await self.sync_handler.wait_for_sync_for_user(
|
||||||
|
requester,
|
||||||
sync_config,
|
sync_config,
|
||||||
since_token=since_token,
|
since_token=since_token,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ class ServerNoticesManager:
|
|||||||
self._room_member_handler = hs.get_room_member_handler()
|
self._room_member_handler = hs.get_room_member_handler()
|
||||||
self._event_creation_handler = hs.get_event_creation_handler()
|
self._event_creation_handler = hs.get_event_creation_handler()
|
||||||
self._is_mine_id = hs.is_mine_id
|
self._is_mine_id = hs.is_mine_id
|
||||||
|
self._server_name = hs.hostname
|
||||||
|
|
||||||
self._notifier = hs.get_notifier()
|
self._notifier = hs.get_notifier()
|
||||||
self.server_notices_mxid = self._config.server_notices_mxid
|
self.server_notices_mxid = self._config.server_notices_mxid
|
||||||
@@ -72,7 +73,9 @@ class ServerNoticesManager:
|
|||||||
await self.maybe_invite_user_to_room(user_id, room_id)
|
await self.maybe_invite_user_to_room(user_id, room_id)
|
||||||
|
|
||||||
system_mxid = self._config.server_notices_mxid
|
system_mxid = self._config.server_notices_mxid
|
||||||
requester = create_requester(system_mxid)
|
requester = create_requester(
|
||||||
|
system_mxid, authenticated_entity=self._server_name
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Sending server notice to %s", user_id)
|
logger.info("Sending server notice to %s", user_id)
|
||||||
|
|
||||||
@@ -145,7 +148,9 @@ class ServerNoticesManager:
|
|||||||
"avatar_url": self._config.server_notices_mxid_avatar_url,
|
"avatar_url": self._config.server_notices_mxid_avatar_url,
|
||||||
}
|
}
|
||||||
|
|
||||||
requester = create_requester(self.server_notices_mxid)
|
requester = create_requester(
|
||||||
|
self.server_notices_mxid, authenticated_entity=self._server_name
|
||||||
|
)
|
||||||
info, _ = await self._room_creation_handler.create_room(
|
info, _ = await self._room_creation_handler.create_room(
|
||||||
requester,
|
requester,
|
||||||
config={
|
config={
|
||||||
@@ -174,7 +179,9 @@ class ServerNoticesManager:
|
|||||||
user_id: The ID of the user to invite.
|
user_id: The ID of the user to invite.
|
||||||
room_id: The ID of the room to invite the user to.
|
room_id: The ID of the room to invite the user to.
|
||||||
"""
|
"""
|
||||||
requester = create_requester(self.server_notices_mxid)
|
requester = create_requester(
|
||||||
|
self.server_notices_mxid, authenticated_entity=self._server_name
|
||||||
|
)
|
||||||
|
|
||||||
# Check whether the user has already joined or been invited to this room. If
|
# Check whether the user has already joined or been invited to this room. If
|
||||||
# that's the case, there is no need to re-invite them.
|
# that's the case, there is no need to re-invite them.
|
||||||
|
|||||||
@@ -146,7 +146,6 @@ class DataStore(
|
|||||||
db_conn, "e2e_cross_signing_keys", "stream_id"
|
db_conn, "e2e_cross_signing_keys", "stream_id"
|
||||||
)
|
)
|
||||||
|
|
||||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
|
||||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||||
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
|
||||||
|
|||||||
@@ -16,29 +16,53 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
import attr
|
||||||
|
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
|
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||||
|
from synapse.storage.databases.main.stats import StatsStore
|
||||||
from synapse.storage.types import Cursor
|
from synapse.storage.types import Cursor
|
||||||
|
from synapse.storage.util.id_generators import IdGenerator
|
||||||
from synapse.storage.util.sequence import build_sequence_generator
|
from synapse.storage.util.sequence import build_sequence_generator
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.server import HomeServer
|
||||||
|
|
||||||
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
THIRTY_MINUTES_IN_MS = 30 * 60 * 1000
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class RegistrationWorkerStore(SQLBaseStore):
|
@attr.s(frozen=True, slots=True)
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
class TokenLookupResult:
|
||||||
|
"""Result of looking up an access token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_id = attr.ib(type=str)
|
||||||
|
is_guest = attr.ib(type=bool, default=False)
|
||||||
|
shadow_banned = attr.ib(type=bool, default=False)
|
||||||
|
token_id = attr.ib(type=Optional[int], default=None)
|
||||||
|
device_id = attr.ib(type=Optional[str], default=None)
|
||||||
|
valid_until_ms = attr.ib(type=Optional[int], default=None)
|
||||||
|
token_owner = attr.ib(type=str)
|
||||||
|
|
||||||
|
@token_owner.default
|
||||||
|
def _default_token_owner(self):
|
||||||
|
return self.user_id
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
|
||||||
|
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
self.clock = hs.get_clock()
|
|
||||||
|
|
||||||
# Note: we don't check this sequence for consistency as we'd have to
|
# Note: we don't check this sequence for consistency as we'd have to
|
||||||
# call `find_max_generated_user_id_localpart` each time, which is
|
# call `find_max_generated_user_id_localpart` each time, which is
|
||||||
@@ -55,7 +79,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
# Create a background job for culling expired 3PID validity tokens
|
# Create a background job for culling expired 3PID validity tokens
|
||||||
if hs.config.run_background_tasks:
|
if hs.config.run_background_tasks:
|
||||||
self.clock.looping_call(
|
self._clock.looping_call(
|
||||||
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
|
self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -92,13 +116,13 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
if not info:
|
if not info:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
now = self.clock.time_msec()
|
now = self._clock.time_msec()
|
||||||
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
|
trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000
|
||||||
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
|
is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms
|
||||||
return is_trial
|
return is_trial
|
||||||
|
|
||||||
@cached()
|
@cached()
|
||||||
async def get_user_by_access_token(self, token: str) -> Optional[dict]:
|
async def get_user_by_access_token(self, token: str) -> Optional[TokenLookupResult]:
|
||||||
"""Get a user from the given access token.
|
"""Get a user from the given access token.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -257,7 +281,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
return await self.db_pool.runInteraction(
|
return await self.db_pool.runInteraction(
|
||||||
"get_users_expiring_soon",
|
"get_users_expiring_soon",
|
||||||
select_users_txn,
|
select_users_txn,
|
||||||
self.clock.time_msec(),
|
self._clock.time_msec(),
|
||||||
self.config.account_validity.renew_at,
|
self.config.account_validity.renew_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -327,19 +351,24 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
await self.db_pool.runInteraction("set_server_admin", set_server_admin_txn)
|
||||||
|
|
||||||
def _query_for_auth(self, txn, token):
|
def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
|
||||||
sql = (
|
sql = """
|
||||||
"SELECT users.name, users.is_guest, users.shadow_banned, access_tokens.id as token_id,"
|
SELECT users.name as user_id,
|
||||||
" access_tokens.device_id, access_tokens.valid_until_ms"
|
users.is_guest,
|
||||||
" FROM users"
|
users.shadow_banned,
|
||||||
" INNER JOIN access_tokens on users.name = access_tokens.user_id"
|
access_tokens.id as token_id,
|
||||||
" WHERE token = ?"
|
access_tokens.device_id,
|
||||||
)
|
access_tokens.valid_until_ms,
|
||||||
|
access_tokens.user_id as token_owner
|
||||||
|
FROM users
|
||||||
|
INNER JOIN access_tokens on users.name = COALESCE(puppets_user_id, access_tokens.user_id)
|
||||||
|
WHERE token = ?
|
||||||
|
"""
|
||||||
|
|
||||||
txn.execute(sql, (token,))
|
txn.execute(sql, (token,))
|
||||||
rows = self.db_pool.cursor_to_dict(txn)
|
rows = self.db_pool.cursor_to_dict(txn)
|
||||||
if rows:
|
if rows:
|
||||||
return rows[0]
|
return TokenLookupResult(**rows[0])
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -803,7 +832,7 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
await self.db_pool.runInteraction(
|
await self.db_pool.runInteraction(
|
||||||
"cull_expired_threepid_validation_tokens",
|
"cull_expired_threepid_validation_tokens",
|
||||||
cull_expired_threepid_validation_tokens_txn,
|
cull_expired_threepid_validation_tokens_txn,
|
||||||
self.clock.time_msec(),
|
self._clock.time_msec(),
|
||||||
)
|
)
|
||||||
|
|
||||||
@wrap_as_background_process("account_validity_set_expiration_dates")
|
@wrap_as_background_process("account_validity_set_expiration_dates")
|
||||||
@@ -890,10 +919,10 @@ class RegistrationWorkerStore(SQLBaseStore):
|
|||||||
|
|
||||||
|
|
||||||
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.clock = hs.get_clock()
|
self._clock = hs.get_clock()
|
||||||
self.config = hs.config
|
self.config = hs.config
|
||||||
|
|
||||||
self.db_pool.updates.register_background_index_update(
|
self.db_pool.updates.register_background_index_update(
|
||||||
@@ -1016,19 +1045,63 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
|
|||||||
|
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
async def set_user_deactivated_status(
|
||||||
|
self, user_id: str, deactivated: bool
|
||||||
|
) -> None:
|
||||||
|
"""Set the `deactivated` property for the provided user to the provided value.
|
||||||
|
|
||||||
class RegistrationStore(RegistrationBackgroundUpdateStore):
|
Args:
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
user_id: The ID of the user to set the status for.
|
||||||
|
deactivated: The value to set for `deactivated`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
await self.db_pool.runInteraction(
|
||||||
|
"set_user_deactivated_status",
|
||||||
|
self.set_user_deactivated_status_txn,
|
||||||
|
user_id,
|
||||||
|
deactivated,
|
||||||
|
)
|
||||||
|
|
||||||
|
def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
|
||||||
|
self.db_pool.simple_update_one_txn(
|
||||||
|
txn=txn,
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user_id},
|
||||||
|
updatevalues={"deactivated": 1 if deactivated else 0},
|
||||||
|
)
|
||||||
|
self._invalidate_cache_and_stream(
|
||||||
|
txn, self.get_user_deactivated_status, (user_id,)
|
||||||
|
)
|
||||||
|
txn.call_after(self.is_guest.invalidate, (user_id,))
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
async def is_guest(self, user_id: str) -> bool:
|
||||||
|
res = await self.db_pool.simple_select_one_onecol(
|
||||||
|
table="users",
|
||||||
|
keyvalues={"name": user_id},
|
||||||
|
retcol="is_guest",
|
||||||
|
allow_none=True,
|
||||||
|
desc="is_guest",
|
||||||
|
)
|
||||||
|
|
||||||
|
return res if res else False
|
||||||
|
|
||||||
|
|
||||||
|
class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||||
|
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors
|
||||||
|
|
||||||
|
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||||
|
|
||||||
async def add_access_token_to_user(
|
async def add_access_token_to_user(
|
||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
token: str,
|
token: str,
|
||||||
device_id: Optional[str],
|
device_id: Optional[str],
|
||||||
valid_until_ms: Optional[int],
|
valid_until_ms: Optional[int],
|
||||||
|
puppets_user_id: Optional[str] = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Adds an access token for the given user.
|
"""Adds an access token for the given user.
|
||||||
|
|
||||||
@@ -1052,6 +1125,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
"token": token,
|
"token": token,
|
||||||
"device_id": device_id,
|
"device_id": device_id,
|
||||||
"valid_until_ms": valid_until_ms,
|
"valid_until_ms": valid_until_ms,
|
||||||
|
"puppets_user_id": puppets_user_id,
|
||||||
},
|
},
|
||||||
desc="add_access_token_to_user",
|
desc="add_access_token_to_user",
|
||||||
)
|
)
|
||||||
@@ -1138,19 +1212,19 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
def _register_user(
|
def _register_user(
|
||||||
self,
|
self,
|
||||||
txn,
|
txn,
|
||||||
user_id,
|
user_id: str,
|
||||||
password_hash,
|
password_hash: Optional[str],
|
||||||
was_guest,
|
was_guest: bool,
|
||||||
make_guest,
|
make_guest: bool,
|
||||||
appservice_id,
|
appservice_id: Optional[str],
|
||||||
create_profile_with_displayname,
|
create_profile_with_displayname: Optional[str],
|
||||||
admin,
|
admin: bool,
|
||||||
user_type,
|
user_type: Optional[str],
|
||||||
shadow_banned,
|
shadow_banned: bool,
|
||||||
):
|
):
|
||||||
user_id_obj = UserID.from_string(user_id)
|
user_id_obj = UserID.from_string(user_id)
|
||||||
|
|
||||||
now = int(self.clock.time())
|
now = int(self._clock.time())
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if was_guest:
|
if was_guest:
|
||||||
@@ -1374,18 +1448,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
|
|
||||||
await self.db_pool.runInteraction("delete_access_token", f)
|
await self.db_pool.runInteraction("delete_access_token", f)
|
||||||
|
|
||||||
@cached()
|
|
||||||
async def is_guest(self, user_id: str) -> bool:
|
|
||||||
res = await self.db_pool.simple_select_one_onecol(
|
|
||||||
table="users",
|
|
||||||
keyvalues={"name": user_id},
|
|
||||||
retcol="is_guest",
|
|
||||||
allow_none=True,
|
|
||||||
desc="is_guest",
|
|
||||||
)
|
|
||||||
|
|
||||||
return res if res else False
|
|
||||||
|
|
||||||
async def add_user_pending_deactivation(self, user_id: str) -> None:
|
async def add_user_pending_deactivation(self, user_id: str) -> None:
|
||||||
"""
|
"""
|
||||||
Adds a user to the table of users who need to be parted from all the rooms they're
|
Adds a user to the table of users who need to be parted from all the rooms they're
|
||||||
@@ -1479,7 +1541,7 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
txn,
|
txn,
|
||||||
table="threepid_validation_session",
|
table="threepid_validation_session",
|
||||||
keyvalues={"session_id": session_id},
|
keyvalues={"session_id": session_id},
|
||||||
updatevalues={"validated_at": self.clock.time_msec()},
|
updatevalues={"validated_at": self._clock.time_msec()},
|
||||||
)
|
)
|
||||||
|
|
||||||
return next_link
|
return next_link
|
||||||
@@ -1547,35 +1609,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
|||||||
start_or_continue_validation_session_txn,
|
start_or_continue_validation_session_txn,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_user_deactivated_status(
|
|
||||||
self, user_id: str, deactivated: bool
|
|
||||||
) -> None:
|
|
||||||
"""Set the `deactivated` property for the provided user to the provided value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The ID of the user to set the status for.
|
|
||||||
deactivated: The value to set for `deactivated`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
await self.db_pool.runInteraction(
|
|
||||||
"set_user_deactivated_status",
|
|
||||||
self.set_user_deactivated_status_txn,
|
|
||||||
user_id,
|
|
||||||
deactivated,
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
|
|
||||||
self.db_pool.simple_update_one_txn(
|
|
||||||
txn=txn,
|
|
||||||
table="users",
|
|
||||||
keyvalues={"name": user_id},
|
|
||||||
updatevalues={"deactivated": 1 if deactivated else 0},
|
|
||||||
)
|
|
||||||
self._invalidate_cache_and_stream(
|
|
||||||
txn, self.get_user_deactivated_status, (user_id,)
|
|
||||||
)
|
|
||||||
txn.call_after(self.is_guest.invalidate, (user_id,))
|
|
||||||
|
|
||||||
|
|
||||||
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
|
def find_max_generated_user_id_localpart(cur: Cursor) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -0,0 +1,17 @@
|
|||||||
|
/* Copyright 2020 The Matrix.org Foundation C.I.C
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
-- Whether the access token is an admin token for controlling another user.
|
||||||
|
ALTER TABLE access_tokens ADD COLUMN puppets_user_id TEXT;
|
||||||
@@ -74,6 +74,7 @@ class Requester(
|
|||||||
"shadow_banned",
|
"shadow_banned",
|
||||||
"device_id",
|
"device_id",
|
||||||
"app_service",
|
"app_service",
|
||||||
|
"authenticated_entity",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
@@ -104,6 +105,7 @@ class Requester(
|
|||||||
"shadow_banned": self.shadow_banned,
|
"shadow_banned": self.shadow_banned,
|
||||||
"device_id": self.device_id,
|
"device_id": self.device_id,
|
||||||
"app_server_id": self.app_service.id if self.app_service else None,
|
"app_server_id": self.app_service.id if self.app_service else None,
|
||||||
|
"authenticated_entity": self.authenticated_entity,
|
||||||
}
|
}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -129,6 +131,7 @@ class Requester(
|
|||||||
shadow_banned=input["shadow_banned"],
|
shadow_banned=input["shadow_banned"],
|
||||||
device_id=input["device_id"],
|
device_id=input["device_id"],
|
||||||
app_service=appservice,
|
app_service=appservice,
|
||||||
|
authenticated_entity=input["authenticated_entity"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -139,6 +142,7 @@ def create_requester(
|
|||||||
shadow_banned=False,
|
shadow_banned=False,
|
||||||
device_id=None,
|
device_id=None,
|
||||||
app_service=None,
|
app_service=None,
|
||||||
|
authenticated_entity=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Create a new ``Requester`` object
|
Create a new ``Requester`` object
|
||||||
@@ -151,14 +155,27 @@ def create_requester(
|
|||||||
shadow_banned (bool): True if the user making this request is shadow-banned.
|
shadow_banned (bool): True if the user making this request is shadow-banned.
|
||||||
device_id (str|None): device_id which was set at authentication time
|
device_id (str|None): device_id which was set at authentication time
|
||||||
app_service (ApplicationService|None): the AS requesting on behalf of the user
|
app_service (ApplicationService|None): the AS requesting on behalf of the user
|
||||||
|
authenticated_entity: The entity that authenticatd when making the request,
|
||||||
|
this is different than the user_id when an admin user or the server is
|
||||||
|
"puppeting" the user.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Requester
|
Requester
|
||||||
"""
|
"""
|
||||||
if not isinstance(user_id, UserID):
|
if not isinstance(user_id, UserID):
|
||||||
user_id = UserID.from_string(user_id)
|
user_id = UserID.from_string(user_id)
|
||||||
|
|
||||||
|
if authenticated_entity is None:
|
||||||
|
authenticated_entity = user_id.to_string()
|
||||||
|
|
||||||
return Requester(
|
return Requester(
|
||||||
user_id, access_token_id, is_guest, shadow_banned, device_id, app_service
|
user_id,
|
||||||
|
access_token_id,
|
||||||
|
is_guest,
|
||||||
|
shadow_banned,
|
||||||
|
device_id,
|
||||||
|
app_service,
|
||||||
|
authenticated_entity,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from synapse.api.errors import (
|
|||||||
MissingClientTokenError,
|
MissingClientTokenError,
|
||||||
ResourceLimitError,
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
|
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
@@ -61,7 +62,9 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_user_by_req_user_valid_token(self):
|
def test_get_user_by_req_user_valid_token(self):
|
||||||
user_info = {"name": self.test_user, "token_id": "ditto", "device_id": "device"}
|
user_info = TokenLookupResult(
|
||||||
|
user_id=self.test_user, token_id=5, device_id="device"
|
||||||
|
)
|
||||||
self.store.get_user_by_access_token = Mock(
|
self.store.get_user_by_access_token = Mock(
|
||||||
return_value=defer.succeed(user_info)
|
return_value=defer.succeed(user_info)
|
||||||
)
|
)
|
||||||
@@ -84,7 +87,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
self.assertEqual(f.errcode, "M_UNKNOWN_TOKEN")
|
||||||
|
|
||||||
def test_get_user_by_req_user_missing_token(self):
|
def test_get_user_by_req_user_missing_token(self):
|
||||||
user_info = {"name": self.test_user, "token_id": "ditto"}
|
user_info = TokenLookupResult(user_id=self.test_user, token_id=5)
|
||||||
self.store.get_user_by_access_token = Mock(
|
self.store.get_user_by_access_token = Mock(
|
||||||
return_value=defer.succeed(user_info)
|
return_value=defer.succeed(user_info)
|
||||||
)
|
)
|
||||||
@@ -221,7 +224,7 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
def test_get_user_from_macaroon(self):
|
def test_get_user_from_macaroon(self):
|
||||||
self.store.get_user_by_access_token = Mock(
|
self.store.get_user_by_access_token = Mock(
|
||||||
return_value=defer.succeed(
|
return_value=defer.succeed(
|
||||||
{"name": "@baldrick:matrix.org", "device_id": "device"}
|
TokenLookupResult(user_id="@baldrick:matrix.org", device_id="device")
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -237,12 +240,11 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
user_info = yield defer.ensureDeferred(
|
user_info = yield defer.ensureDeferred(
|
||||||
self.auth.get_user_by_access_token(macaroon.serialize())
|
self.auth.get_user_by_access_token(macaroon.serialize())
|
||||||
)
|
)
|
||||||
user = user_info["user"]
|
self.assertEqual(user_id, user_info.user_id)
|
||||||
self.assertEqual(UserID.from_string(user_id), user)
|
|
||||||
|
|
||||||
# TODO: device_id should come from the macaroon, but currently comes
|
# TODO: device_id should come from the macaroon, but currently comes
|
||||||
# from the db.
|
# from the db.
|
||||||
self.assertEqual(user_info["device_id"], "device")
|
self.assertEqual(user_info.device_id, "device")
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_get_guest_user_from_macaroon(self):
|
def test_get_guest_user_from_macaroon(self):
|
||||||
@@ -264,10 +266,8 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
user_info = yield defer.ensureDeferred(
|
user_info = yield defer.ensureDeferred(
|
||||||
self.auth.get_user_by_access_token(serialized)
|
self.auth.get_user_by_access_token(serialized)
|
||||||
)
|
)
|
||||||
user = user_info["user"]
|
self.assertEqual(user_id, user_info.user_id)
|
||||||
is_guest = user_info["is_guest"]
|
self.assertTrue(user_info.is_guest)
|
||||||
self.assertEqual(UserID.from_string(user_id), user)
|
|
||||||
self.assertTrue(is_guest)
|
|
||||||
self.store.get_user_by_id.assert_called_with(user_id)
|
self.store.get_user_by_id.assert_called_with(user_id)
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
@@ -289,12 +289,9 @@ class AuthTestCase(unittest.TestCase):
|
|||||||
if token != tok:
|
if token != tok:
|
||||||
return defer.succeed(None)
|
return defer.succeed(None)
|
||||||
return defer.succeed(
|
return defer.succeed(
|
||||||
{
|
TokenLookupResult(
|
||||||
"name": USER_ID,
|
user_id=USER_ID, is_guest=False, token_id=1234, device_id="DEVICE",
|
||||||
"is_guest": False,
|
)
|
||||||
"token_id": 1234,
|
|
||||||
"device_id": "DEVICE",
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.store.get_user_by_access_token = get_user
|
self.store.get_user_by_access_token = get_user
|
||||||
|
|||||||
@@ -289,7 +289,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
|||||||
# make sure that our device ID has changed
|
# make sure that our device ID has changed
|
||||||
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
|
user_info = self.get_success(self.auth.get_user_by_access_token(access_token))
|
||||||
|
|
||||||
self.assertEqual(user_info["device_id"], retrieved_device_id)
|
self.assertEqual(user_info.device_id, retrieved_device_id)
|
||||||
|
|
||||||
# make sure the device has the display name that was set from the login
|
# make sure the device has the display name that was set from the login
|
||||||
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
|
res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
|
|||||||
self.info = self.get_success(
|
self.info = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
|
self.hs.get_datastore().get_user_by_access_token(self.access_token,)
|
||||||
)
|
)
|
||||||
self.token_id = self.info["token_id"]
|
self.token_id = self.info.token_id
|
||||||
|
|
||||||
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
self.requester = create_requester(self.user_id, access_token_id=self.token_id)
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
from synapse.api.errors import Codes, ResourceLimitError
|
from synapse.api.errors import Codes, ResourceLimitError
|
||||||
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
|
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION
|
||||||
from synapse.handlers.sync import SyncConfig
|
from synapse.handlers.sync import SyncConfig
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
import tests.unittest
|
import tests.unittest
|
||||||
import tests.utils
|
import tests.utils
|
||||||
@@ -38,6 +38,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
|||||||
user_id1 = "@user1:test"
|
user_id1 = "@user1:test"
|
||||||
user_id2 = "@user2:test"
|
user_id2 = "@user2:test"
|
||||||
sync_config = self._generate_sync_config(user_id1)
|
sync_config = self._generate_sync_config(user_id1)
|
||||||
|
requester = create_requester(user_id1)
|
||||||
|
|
||||||
self.reactor.advance(100) # So we get not 0 time
|
self.reactor.advance(100) # So we get not 0 time
|
||||||
self.auth_blocking._limit_usage_by_mau = True
|
self.auth_blocking._limit_usage_by_mau = True
|
||||||
@@ -45,21 +46,26 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# Check that the happy case does not throw errors
|
# Check that the happy case does not throw errors
|
||||||
self.get_success(self.store.upsert_monthly_active_user(user_id1))
|
self.get_success(self.store.upsert_monthly_active_user(user_id1))
|
||||||
self.get_success(self.sync_handler.wait_for_sync_for_user(sync_config))
|
self.get_success(
|
||||||
|
self.sync_handler.wait_for_sync_for_user(requester, sync_config)
|
||||||
|
)
|
||||||
|
|
||||||
# Test that global lock works
|
# Test that global lock works
|
||||||
self.auth_blocking._hs_disabled = True
|
self.auth_blocking._hs_disabled = True
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
|
self.sync_handler.wait_for_sync_for_user(requester, sync_config),
|
||||||
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
|
|
||||||
self.auth_blocking._hs_disabled = False
|
self.auth_blocking._hs_disabled = False
|
||||||
|
|
||||||
sync_config = self._generate_sync_config(user_id2)
|
sync_config = self._generate_sync_config(user_id2)
|
||||||
|
requester = create_requester(user_id2)
|
||||||
|
|
||||||
e = self.get_failure(
|
e = self.get_failure(
|
||||||
self.sync_handler.wait_for_sync_for_user(sync_config), ResourceLimitError
|
self.sync_handler.wait_for_sync_for_user(requester, sync_config),
|
||||||
|
ResourceLimitError,
|
||||||
)
|
)
|
||||||
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
|
||||||
|
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class EmailPusherTests(HomeserverTestCase):
|
|||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(self.access_token)
|
self.hs.get_datastore().get_user_by_access_token(self.access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.pusher = self.get_success(
|
self.pusher = self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
|||||||
@@ -69,7 +69,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
@@ -181,7 +181,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
@@ -297,7 +297,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
@@ -379,7 +379,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
@@ -452,7 +452,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
|||||||
user_tuple = self.get_success(
|
user_tuple = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_tuple["token_id"]
|
token_id = user_tuple.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
|||||||
user_dict = self.get_success(
|
user_dict = self.get_success(
|
||||||
self.hs.get_datastore().get_user_by_access_token(access_token)
|
self.hs.get_datastore().get_user_by_access_token(access_token)
|
||||||
)
|
)
|
||||||
token_id = user_dict["token_id"]
|
token_id = user_dict.token_id
|
||||||
|
|
||||||
self.get_success(
|
self.get_success(
|
||||||
self.hs.get_pusherpool().add_pusher(
|
self.hs.get_pusherpool().add_pusher(
|
||||||
|
|||||||
@@ -23,8 +23,8 @@ from mock import Mock
|
|||||||
import synapse.rest.admin
|
import synapse.rest.admin
|
||||||
from synapse.api.constants import UserTypes
|
from synapse.api.constants import UserTypes
|
||||||
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
|
from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
|
||||||
from synapse.rest.client.v1 import login, room
|
from synapse.rest.client.v1 import login, logout, room
|
||||||
from synapse.rest.client.v2_alpha import sync
|
from synapse.rest.client.v2_alpha import devices, sync
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import make_awaitable
|
from tests.test_utils import make_awaitable
|
||||||
@@ -1101,3 +1101,244 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
|
|||||||
self.assertEqual(200, channel.code, msg=channel.json_body)
|
self.assertEqual(200, channel.code, msg=channel.json_body)
|
||||||
self.assertEqual(number_rooms, channel.json_body["total"])
|
self.assertEqual(number_rooms, channel.json_body["total"])
|
||||||
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
|
self.assertEqual(number_rooms, len(channel.json_body["joined_rooms"]))
|
||||||
|
|
||||||
|
|
||||||
|
class UserTokenRestTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""Test for /_synapse/admin/v1/users/<user>/login
|
||||||
|
"""
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
synapse.rest.admin.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
sync.register_servlets,
|
||||||
|
room.register_servlets,
|
||||||
|
devices.register_servlets,
|
||||||
|
logout.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def prepare(self, reactor, clock, hs):
|
||||||
|
self.store = hs.get_datastore()
|
||||||
|
|
||||||
|
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||||
|
self.admin_user_tok = self.login("admin", "pass")
|
||||||
|
|
||||||
|
self.other_user = self.register_user("user", "pass")
|
||||||
|
self.other_user_tok = self.login("user", "pass")
|
||||||
|
self.url = "/_synapse/admin/v1/users/%s/login" % urllib.parse.quote(
|
||||||
|
self.other_user
|
||||||
|
)
|
||||||
|
|
||||||
|
def _get_token(self) -> str:
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"PUT", self.url, b"{}", access_token=self.admin_user_tok
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
return channel.json_body["access_token"]
|
||||||
|
|
||||||
|
def test_no_auth(self):
|
||||||
|
"""Try to login as a user without authentication.
|
||||||
|
"""
|
||||||
|
request, channel = self.make_request("PUT", self.url, b"{}")
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
|
||||||
|
|
||||||
|
def test_not_admin(self):
|
||||||
|
"""Try to login as a user as a non-admin user.
|
||||||
|
"""
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"PUT", self.url, b"{}", access_token=self.other_user_tok
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
|
||||||
|
self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
def test_send_event(self):
|
||||||
|
"""Test that sending event as a user works.
|
||||||
|
"""
|
||||||
|
# Create a room.
|
||||||
|
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
|
||||||
|
|
||||||
|
# Login in as the user
|
||||||
|
puppet_token = self._get_token()
|
||||||
|
|
||||||
|
# Test that sending works, and generates the event as the right user.
|
||||||
|
resp = self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
|
||||||
|
event_id = resp["event_id"]
|
||||||
|
event = self.get_success(self.store.get_event(event_id))
|
||||||
|
self.assertEqual(event.sender, self.other_user)
|
||||||
|
|
||||||
|
def test_devices(self):
|
||||||
|
"""Tests that logging in as a user doesn't create a new device for them.
|
||||||
|
"""
|
||||||
|
# Login in as the user
|
||||||
|
self._get_token()
|
||||||
|
|
||||||
|
# Check that we don't see a new device in our devices list
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=self.other_user_tok
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# We should only see the one device (from the login in `prepare`)
|
||||||
|
self.assertEqual(len(channel.json_body["devices"]), 1)
|
||||||
|
|
||||||
|
def test_logout(self):
|
||||||
|
"""Test that calling `/logout` with the token works.
|
||||||
|
"""
|
||||||
|
# Login in as the user
|
||||||
|
puppet_token = self._get_token()
|
||||||
|
|
||||||
|
# Test that we can successfully make a request
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=puppet_token
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# Logout with the puppet token
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"POST", "logout", b"{}", access_token=puppet_token
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# The puppet token should no longer work
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=puppet_token
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# .. but the real user's tokens should still work
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=self.other_user_tok
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
def test_user_logout_all(self):
|
||||||
|
"""Tests that the target user calling `/logout/all` does *not* expire
|
||||||
|
the token.
|
||||||
|
"""
|
||||||
|
# Login in as the user
|
||||||
|
puppet_token = self._get_token()
|
||||||
|
|
||||||
|
# Test that we can successfully make a request
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=puppet_token
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# Logout all with the real user token
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"POST", "logout/all", b"{}", access_token=self.other_user_tok
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# The puppet token should still work
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=puppet_token
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# .. but the real user's tokens shouldn't
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=self.other_user_tok
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
def test_admin_logout_all(self):
|
||||||
|
"""Tests that the admin user calling `/logout/all` does expire the
|
||||||
|
token.
|
||||||
|
"""
|
||||||
|
# Login in as the user
|
||||||
|
puppet_token = self._get_token()
|
||||||
|
|
||||||
|
# Test that we can successfully make a request
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=puppet_token
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# Logout all with the admin user token
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"POST", "logout/all", b"{}", access_token=self.admin_user_tok
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# The puppet token should no longer work
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=puppet_token
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
# .. but the real user's tokens should still work
|
||||||
|
request, channel = self.make_request(
|
||||||
|
"GET", "devices", b"{}", access_token=self.other_user_tok
|
||||||
|
)
|
||||||
|
self.render(request)
|
||||||
|
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
|
||||||
|
|
||||||
|
@unittest.override_config(
|
||||||
|
{
|
||||||
|
"public_baseurl": "https://example.org/",
|
||||||
|
"user_consent": {
|
||||||
|
"version": "1.0",
|
||||||
|
"policy_name": "My Cool Privacy Policy",
|
||||||
|
"template_dir": "/",
|
||||||
|
"require_at_registration": True,
|
||||||
|
"block_events_error": "You should accept the policy",
|
||||||
|
},
|
||||||
|
"form_secret": "123secret",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
def test_consent(self):
|
||||||
|
"""Test that sending a message is not subject to the privacy policies.
|
||||||
|
"""
|
||||||
|
# Have the admin user accept the terms.
|
||||||
|
self.get_success(self.store.user_set_consent_version(self.admin_user, "1.0"))
|
||||||
|
|
||||||
|
# First, cheekily accept the terms and create a room
|
||||||
|
self.get_success(self.store.user_set_consent_version(self.other_user, "1.0"))
|
||||||
|
room_id = self.helper.create_room_as(self.other_user, tok=self.other_user_tok)
|
||||||
|
self.helper.send_event(room_id, "com.example.test", tok=self.other_user_tok)
|
||||||
|
|
||||||
|
# Now unaccept it and check that we can't send an event
|
||||||
|
self.get_success(self.store.user_set_consent_version(self.other_user, "0.0"))
|
||||||
|
self.helper.send_event(
|
||||||
|
room_id, "com.example.test", tok=self.other_user_tok, expect_code=403
|
||||||
|
)
|
||||||
|
|
||||||
|
# Login in as the user
|
||||||
|
puppet_token = self._get_token()
|
||||||
|
|
||||||
|
# Sending an event on their behalf should work fine
|
||||||
|
self.helper.send_event(room_id, "com.example.test", tok=puppet_token)
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{"limit_usage_by_mau": True, "max_mau_value": 1, "mau_trial_days": 0}
|
||||||
|
)
|
||||||
|
def test_mau_limit(self):
|
||||||
|
# Create a room as the admin user. This will bump the monthly active users to 1.
|
||||||
|
room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok)
|
||||||
|
|
||||||
|
# Trying to join as the other user should fail.
|
||||||
|
self.helper.join(
|
||||||
|
room_id, user=self.other_user, tok=self.other_user_tok, expect_code=403
|
||||||
|
)
|
||||||
|
|
||||||
|
# Logging in as the other user and joining a room should work, even
|
||||||
|
# though they should be denied.
|
||||||
|
puppet_token = self._get_token()
|
||||||
|
self.helper.join(room_id, user=self.other_user, tok=puppet_token)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import synapse.rest.admin
|
|||||||
from synapse.api.constants import EventTypes
|
from synapse.api.constants import EventTypes
|
||||||
from synapse.rest.client.v1 import login, room
|
from synapse.rest.client.v1 import login, room
|
||||||
from synapse.storage import prepare_database
|
from synapse.storage import prepare_database
|
||||||
from synapse.types import Requester, UserID
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
@@ -38,7 +38,7 @@ class CleanupExtremBackgroundUpdateStoreTestCase(HomeserverTestCase):
|
|||||||
|
|
||||||
# Create a test user and room
|
# Create a test user and room
|
||||||
self.user = UserID("alice", "test")
|
self.user = UserID("alice", "test")
|
||||||
self.requester = Requester(self.user, None, False, False, None, None)
|
self.requester = create_requester(self.user)
|
||||||
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
|
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
|
||||||
self.room_id = info["room_id"]
|
self.room_id = info["room_id"]
|
||||||
|
|
||||||
@@ -260,7 +260,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
|||||||
# Create a test user and room
|
# Create a test user and room
|
||||||
self.user = UserID.from_string(self.register_user("user1", "password"))
|
self.user = UserID.from_string(self.register_user("user1", "password"))
|
||||||
self.token1 = self.login("user1", "password")
|
self.token1 = self.login("user1", "password")
|
||||||
self.requester = Requester(self.user, None, False, False, None, None)
|
self.requester = create_requester(self.user)
|
||||||
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
|
info, _ = self.get_success(self.room_creator.create_room(self.requester, {}))
|
||||||
self.room_id = info["room_id"]
|
self.room_id = info["room_id"]
|
||||||
self.event_creator = homeserver.get_event_creation_handler()
|
self.event_creator = homeserver.get_event_creation_handler()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from synapse.metrics import REGISTRY, generate_latest
|
from synapse.metrics import REGISTRY, generate_latest
|
||||||
from synapse.types import Requester, UserID
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
@@ -27,7 +27,7 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
|
|||||||
room_creator = self.hs.get_room_creation_handler()
|
room_creator = self.hs.get_room_creation_handler()
|
||||||
|
|
||||||
user = UserID("alice", "test")
|
user = UserID("alice", "test")
|
||||||
requester = Requester(user, None, False, False, None, None)
|
requester = create_requester(user)
|
||||||
|
|
||||||
# Real events, forward extremities
|
# Real events, forward extremities
|
||||||
events = [(3, 2), (6, 2), (4, 6)]
|
events = [(3, 2), (6, 2), (4, 6)]
|
||||||
|
|||||||
@@ -69,11 +69,9 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||||||
self.store.get_user_by_access_token(self.tokens[1])
|
self.store.get_user_by_access_token(self.tokens[1])
|
||||||
)
|
)
|
||||||
|
|
||||||
self.assertDictContainsSubset(
|
self.assertEqual(result.user_id, self.user_id)
|
||||||
{"name": self.user_id, "device_id": self.device_id}, result
|
self.assertEqual(result.device_id, self.device_id)
|
||||||
)
|
self.assertIsNotNone(result.token_id)
|
||||||
|
|
||||||
self.assertTrue("token_id" in result)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_user_delete_access_tokens(self):
|
def test_user_delete_access_tokens(self):
|
||||||
@@ -105,7 +103,7 @@ class RegistrationStoreTestCase(unittest.TestCase):
|
|||||||
user = yield defer.ensureDeferred(
|
user = yield defer.ensureDeferred(
|
||||||
self.store.get_user_by_access_token(self.tokens[0])
|
self.store.get_user_by_access_token(self.tokens[0])
|
||||||
)
|
)
|
||||||
self.assertEqual(self.user_id, user["name"])
|
self.assertEqual(self.user_id, user.user_id)
|
||||||
|
|
||||||
# now delete the rest
|
# now delete the rest
|
||||||
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
|
yield defer.ensureDeferred(self.store.user_delete_access_tokens(self.user_id))
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from unittest.mock import Mock
|
|||||||
from synapse.api.constants import Membership
|
from synapse.api.constants import Membership
|
||||||
from synapse.rest.admin import register_servlets_for_client_rest_resource
|
from synapse.rest.admin import register_servlets_for_client_rest_resource
|
||||||
from synapse.rest.client.v1 import login, room
|
from synapse.rest.client.v1 import login, room
|
||||||
from synapse.types import Requester, UserID
|
from synapse.types import UserID, create_requester
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.test_utils import event_injection
|
from tests.test_utils import event_injection
|
||||||
@@ -187,7 +187,7 @@ class CurrentStateMembershipUpdateTestCase(unittest.HomeserverTestCase):
|
|||||||
|
|
||||||
# Now let's create a room, which will insert a membership
|
# Now let's create a room, which will insert a membership
|
||||||
user = UserID("alice", "test")
|
user = UserID("alice", "test")
|
||||||
requester = Requester(user, None, False, False, None, None)
|
requester = create_requester(user)
|
||||||
self.get_success(self.room_creator.create_room(requester, {}))
|
self.get_success(self.room_creator.create_room(requester, {}))
|
||||||
|
|
||||||
# Register the background update to run again.
|
# Register the background update to run again.
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from twisted.internet.defer import succeed
|
|||||||
from synapse.api.errors import FederationError
|
from synapse.api.errors import FederationError
|
||||||
from synapse.events import make_event_from_dict
|
from synapse.events import make_event_from_dict
|
||||||
from synapse.logging.context import LoggingContext
|
from synapse.logging.context import LoggingContext
|
||||||
from synapse.types import Requester, UserID
|
from synapse.types import UserID, create_requester
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
from synapse.util.retryutils import NotRetryingDestination
|
from synapse.util.retryutils import NotRetryingDestination
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
user_id = UserID("us", "test")
|
user_id = UserID("us", "test")
|
||||||
our_user = Requester(user_id, None, False, False, None, None)
|
our_user = create_requester(user_id)
|
||||||
room_creator = self.homeserver.get_room_creation_handler()
|
room_creator = self.homeserver.get_room_creation_handler()
|
||||||
self.room_id = self.get_success(
|
self.room_id = self.get_success(
|
||||||
room_creator.create_room(
|
room_creator.create_room(
|
||||||
|
|||||||
@@ -169,6 +169,7 @@ class StateTestCase(unittest.TestCase):
|
|||||||
"get_state_handler",
|
"get_state_handler",
|
||||||
"get_clock",
|
"get_clock",
|
||||||
"get_state_resolution_handler",
|
"get_state_resolution_handler",
|
||||||
|
"hostname",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
hs.config = default_config("tesths", True)
|
hs.config = default_config("tesths", True)
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ from synapse.logging.context import (
|
|||||||
set_current_context,
|
set_current_context,
|
||||||
)
|
)
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
from synapse.types import Requester, UserID, create_requester
|
from synapse.types import UserID, create_requester
|
||||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||||
|
|
||||||
from tests.server import (
|
from tests.server import (
|
||||||
@@ -627,7 +627,7 @@ class HomeserverTestCase(TestCase):
|
|||||||
"""
|
"""
|
||||||
event_creator = self.hs.get_event_creation_handler()
|
event_creator = self.hs.get_event_creation_handler()
|
||||||
secrets = self.hs.get_secrets()
|
secrets = self.hs.get_secrets()
|
||||||
requester = Requester(user, None, False, False, None, None)
|
requester = create_requester(user)
|
||||||
|
|
||||||
event, context = self.get_success(
|
event, context = self.get_success(
|
||||||
event_creator.create_event(
|
event_creator.create_event(
|
||||||
|
|||||||
Reference in New Issue
Block a user