Make user_type extensible and allow default user_type to be set (#18456)

This commit is contained in:
Hugh Nimmo-Smith
2025-06-03 12:34:40 +01:00
committed by GitHub
parent fae72f181b
commit a4d8da7a1b
12 changed files with 306 additions and 47 deletions

View File

@@ -0,0 +1 @@
Support configuration of default and extra user types.

View File

@@ -163,7 +163,8 @@ Body parameters:
- `locked` - **bool**, optional. If unspecified, locked state will be left unchanged.
- `user_type` - **string** or null, optional. If not provided, the user type will be
not be changed. If `null` is given, the user type will be cleared.
Other allowed options are: `bot` and `support`.
Other allowed options are: `bot` and `support` and any extra values defined in the homserver
[configuration](../usage/configuration/config_documentation.md#user_types).
## List Accounts
### List Accounts (V2)

View File

@@ -762,6 +762,24 @@ Example configuration:
max_event_delay_duration: 24h
```
---
### `user_types`
Configuration settings related to the user types feature.
This setting has the following sub-options:
* `default_user_type`: The default user type to use for registering new users when no value has been specified.
Defaults to none.
* `extra_user_types`: Array of additional user types to allow. These are treated as real users. Defaults to [].
Example configuration:
```yaml
user_types:
default_user_type: "custom"
extra_user_types:
- "custom"
- "custom2"
```
## Homeserver blocking
Useful options for Synapse admins.

View File

@@ -185,12 +185,18 @@ ServerNoticeLimitReached: Final = "m.server_notice.usage_limit_reached"
class UserTypes:
"""Allows for user type specific behaviour. With the benefit of hindsight
'admin' and 'guest' users should also be UserTypes. Normal users are type None
'admin' and 'guest' users should also be UserTypes. Extra user types can be
added in the configuration. Normal users are type None or one of the extra
user types (if configured).
"""
SUPPORT: Final = "support"
BOT: Final = "bot"
ALL_USER_TYPES: Final = (SUPPORT, BOT)
ALL_BUILTIN_USER_TYPES: Final = (SUPPORT, BOT)
"""
The user types that are built-in to Synapse. Extra user types can be
added in the configuration.
"""
class RelationTypes:

View File

@@ -59,6 +59,7 @@ from synapse.config import ( # noqa: F401
tls,
tracer,
user_directory,
user_types,
voip,
workers,
)
@@ -122,6 +123,7 @@ class RootConfig:
retention: retention.RetentionConfig
background_updates: background_updates.BackgroundUpdateConfig
auto_accept_invites: auto_accept_invites.AutoAcceptInvitesConfig
user_types: user_types.UserTypesConfig
config_classes: List[Type["Config"]] = ...
config_files: List[str]

View File

@@ -59,6 +59,7 @@ from .third_party_event_rules import ThirdPartyRulesConfig
from .tls import TlsConfig
from .tracer import TracerConfig
from .user_directory import UserDirectoryConfig
from .user_types import UserTypesConfig
from .voip import VoipConfig
from .workers import WorkerConfig
@@ -107,4 +108,5 @@ class HomeServerConfig(RootConfig):
ExperimentalConfig,
BackgroundUpdateConfig,
AutoAcceptInvitesConfig,
UserTypesConfig,
]

View File

@@ -0,0 +1,44 @@
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright (C) 2025 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
from typing import Any, List, Optional
from synapse.api.constants import UserTypes
from synapse.types import JsonDict
from ._base import Config, ConfigError
class UserTypesConfig(Config):
section = "user_types"
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
user_types: JsonDict = config.get("user_types", {})
self.default_user_type: Optional[str] = user_types.get(
"default_user_type", None
)
self.extra_user_types: List[str] = user_types.get("extra_user_types", [])
all_user_types: List[str] = []
all_user_types.extend(UserTypes.ALL_BUILTIN_USER_TYPES)
all_user_types.extend(self.extra_user_types)
self.all_user_types = all_user_types
if self.default_user_type is not None:
if self.default_user_type not in all_user_types:
raise ConfigError(
f"Default user type {self.default_user_type} is not in the list of all user types: {all_user_types}"
)

View File

@@ -115,6 +115,7 @@ class RegistrationHandler:
self._user_consent_version = self.hs.config.consent.user_consent_version
self._server_notices_mxid = hs.config.servernotices.server_notices_mxid
self._server_name = hs.hostname
self._user_types_config = hs.config.user_types
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
@@ -306,6 +307,9 @@ class RegistrationHandler:
elif default_display_name is None:
default_display_name = localpart
if user_type is None:
user_type = self._user_types_config.default_user_type
await self.register_with_store(
user_id=user_id,
password_hash=password_hash,

View File

@@ -28,7 +28,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import attr
from synapse._pydantic_compat import StrictBool, StrictInt, StrictStr
from synapse.api.constants import Direction, UserTypes
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import (
RestServlet,
@@ -230,6 +230,7 @@ class UserRestServletV2(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool()
self._msc3866_enabled = hs.config.experimental.msc3866.enabled
self._all_user_types = hs.config.user_types.all_user_types
async def on_GET(
self, request: SynapseRequest, user_id: str
@@ -277,7 +278,7 @@ class UserRestServletV2(RestServlet):
assert_params_in_dict(external_id, ["auth_provider", "external_id"])
user_type = body.get("user_type", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
if user_type is not None and user_type not in self._all_user_types:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
set_admin_to = body.get("admin", False)
@@ -524,6 +525,7 @@ class UserRegisterServlet(RestServlet):
self.reactor = hs.get_reactor()
self.nonces: Dict[str, int] = {}
self.hs = hs
self._all_user_types = hs.config.user_types.all_user_types
def _clear_old_nonces(self) -> None:
"""
@@ -605,7 +607,7 @@ class UserRegisterServlet(RestServlet):
user_type = body.get("user_type", None)
displayname = body.get("displayname", None)
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
if user_type is not None and user_type not in self._all_user_types:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid user type")
if "mac" not in body:

View File

@@ -583,7 +583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
await self.db_pool.runInteraction("set_shadow_banned", set_shadow_banned_txn)
async def set_user_type(self, user: UserID, user_type: Optional[UserTypes]) -> None:
async def set_user_type(
self, user: UserID, user_type: Optional[Union[UserTypes, str]]
) -> None:
"""Sets the user type.
Args:
@@ -683,7 +685,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
retcol="user_type",
allow_none=True,
)
return res is None
return res is None or res not in [UserTypes.BOT, UserTypes.SUPPORT]
def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
res = self.db_pool.simple_select_one_onecol_txn(
@@ -959,10 +961,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
return await self.db_pool.runInteraction("count_users", _count_users)
async def count_real_users(self) -> int:
"""Counts all users without a special user_type registered on the homeserver."""
"""Counts all users without the bot or support user_types registered on the homeserver."""
def _count_users(txn: LoggingTransaction) -> int:
txn.execute("SELECT COUNT(*) FROM users where user_type is null")
txn.execute(
f"SELECT COUNT(*) FROM users WHERE user_type IS NULL OR user_type NOT IN ('{UserTypes.BOT}', '{UserTypes.SUPPORT}')"
)
row = txn.fetchone()
assert row is not None
return row[0]
@@ -2545,7 +2549,8 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
the user, setting their displayname to the given value
admin: is an admin user?
user_type: type of user. One of the values from api.constants.UserTypes,
or None for a normal user.
a custom value set in the configuration file, or None for a normal
user.
shadow_banned: Whether the user is shadow-banned, i.e. they may be
told their requests succeeded but we ignore them.
approved: Whether to consider the user has already been approved by an

View File

@@ -738,6 +738,41 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
self.handler.register_user(localpart="bobflimflob", auth_provider_id="saml")
)
def test_register_default_user_type(self) -> None:
"""Test that the default user type is none when registering a user."""
user_id = self.get_success(self.handler.register_user(localpart="user"))
user_info = self.get_success(self.store.get_user_by_id(user_id))
assert user_info is not None
self.assertEqual(user_info.user_type, None)
def test_register_extra_user_types_valid(self) -> None:
"""
Test that the specified user type is set correctly when registering a user.
n.b. No validation is done on the user type, so this test
is only to ensure that the user type can be set to any value.
"""
user_id = self.get_success(
self.handler.register_user(localpart="user", user_type="anyvalue")
)
user_info = self.get_success(self.store.get_user_by_id(user_id))
assert user_info is not None
self.assertEqual(user_info.user_type, "anyvalue")
@override_config(
{
"user_types": {
"extra_user_types": ["extra1", "extra2"],
"default_user_type": "extra1",
}
}
)
def test_register_extra_user_types_with_default(self) -> None:
"""Test that the default_user_type in config is set correctly when registering a user."""
user_id = self.get_success(self.handler.register_user(localpart="user"))
user_info = self.get_success(self.store.get_user_by_id(user_id))
assert user_info is not None
self.assertEqual(user_info.user_type, "extra1")
async def get_or_create_user(
self,
requester: Requester,

View File

@@ -328,6 +328,61 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
@override_config(
{
"user_types": {
"extra_user_types": ["extra1", "extra2"],
}
}
)
def test_extra_user_type(self) -> None:
"""
Check that the extra user type can be used when registering a user.
"""
def nonce_mac(user_type: str) -> tuple[str, str]:
"""
Get a nonce and the expected HMAC for that nonce.
"""
channel = self.make_request("GET", self.url)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
want_mac.update(
nonce.encode("ascii")
+ b"\x00alice\x00abc123\x00notadmin\x00"
+ user_type.encode("ascii")
)
want_mac_str = want_mac.hexdigest()
return nonce, want_mac_str
nonce, mac = nonce_mac("extra1")
# Valid user_type
body = {
"nonce": nonce,
"username": "alice",
"password": "abc123",
"user_type": "extra1",
"mac": mac,
}
channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
nonce, mac = nonce_mac("extra3")
# Invalid user_type
body = {
"nonce": nonce,
"username": "alice",
"password": "abc123",
"user_type": "extra3",
"mac": mac,
}
channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_displayname(self) -> None:
"""
Test that displayname of new user is set
@@ -1186,6 +1241,80 @@ class UsersListTestCase(unittest.HomeserverTestCase):
not_user_types=["custom"],
)
@override_config(
{
"user_types": {
"extra_user_types": ["extra1", "extra2"],
}
}
)
def test_filter_not_user_types_with_extra(self) -> None:
"""Tests that the endpoint handles the not_user_types param when extra_user_types are configured"""
regular_user_id = self.register_user("normalo", "secret")
extra1_user_id = self.register_user("extra1", "secret")
self.make_request(
"PUT",
"/_synapse/admin/v2/users/" + urllib.parse.quote(extra1_user_id),
{"user_type": "extra1"},
access_token=self.admin_user_tok,
)
def test_user_type(
expected_user_ids: List[str], not_user_types: Optional[List[str]] = None
) -> None:
"""Runs a test for the not_user_types param
Args:
expected_user_ids: Ids of the users that are expected to be returned
not_user_types: List of values for the not_user_types param
"""
user_type_query = ""
if not_user_types is not None:
user_type_query = "&".join(
[f"not_user_type={u}" for u in not_user_types]
)
test_url = f"{self.url}?{user_type_query}"
channel = self.make_request(
"GET",
test_url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code)
self.assertEqual(channel.json_body["total"], len(expected_user_ids))
self.assertEqual(
expected_user_ids,
[u["name"] for u in channel.json_body["users"]],
)
# Request without user_types → all users expected
test_user_type([self.admin_user, extra1_user_id, regular_user_id])
# Request and exclude extra1 user type
test_user_type(
[self.admin_user, regular_user_id],
not_user_types=["extra1"],
)
# Request and exclude extra1 and extra2 user types
test_user_type(
[self.admin_user, regular_user_id],
not_user_types=["extra1", "extra2"],
)
# Request and exclude empty user types → only expected the extra1 user
test_user_type([extra1_user_id], not_user_types=[""])
# Request and exclude an unregistered type → expect all users
test_user_type(
[self.admin_user, extra1_user_id, regular_user_id],
not_user_types=["extra3"],
)
def test_erasure_status(self) -> None:
# Create a new user.
user_id = self.register_user("eraseme", "eraseme")
@@ -2977,56 +3106,66 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual("@user:test", channel.json_body["name"])
self.assertTrue(channel.json_body["admin"])
def set_user_type(self, user_type: Optional[str]) -> None:
# Set to user_type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"user_type": user_type},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(user_type, channel.json_body["user_type"])
# Get user
channel = self.make_request(
"GET",
self.url_other_user,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(user_type, channel.json_body["user_type"])
def test_set_user_type(self) -> None:
"""
Test changing user type.
"""
# Set to support type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"user_type": UserTypes.SUPPORT},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
# Get user
channel = self.make_request(
"GET",
self.url_other_user,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertEqual(UserTypes.SUPPORT, channel.json_body["user_type"])
self.set_user_type(UserTypes.SUPPORT)
# Change back to a regular user
self.set_user_type(None)
@override_config({"user_types": {"extra_user_types": ["extra1", "extra2"]}})
def test_set_user_type_with_extras(self) -> None:
"""
Test changing user type with extra_user_types configured.
"""
# Check that we can still set to support type
self.set_user_type(UserTypes.SUPPORT)
# Check that we can set to an extra user type
self.set_user_type("extra2")
# Change back to a regular user
self.set_user_type(None)
# Try setting to invalid type
channel = self.make_request(
"PUT",
self.url_other_user,
access_token=self.admin_user_tok,
content={"user_type": None},
content={"user_type": "extra3"},
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
# Get user
channel = self.make_request(
"GET",
self.url_other_user,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@user:test", channel.json_body["name"])
self.assertIsNone(channel.json_body["user_type"])
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
def test_accidental_deactivation_prevention(self) -> None:
"""