Speed up device deletion (#18602)

This is to handle the case of deleting lots of "bot" devices at once.

Reviewable commit-by-commit

---------

Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
Author: Erik Johnston
Date: 2025-06-30 11:48:57 +01:00 (committed by GitHub)
Commit: 3878699df7 (parent: b35c6483d5)
7 changed files with 225 additions and 84 deletions

changelog.d/18602.misc (new file)

@@ -0,0 +1 @@
+Speed up bulk device deletion.

synapse/handlers/auth.py

@@ -76,7 +76,7 @@ from synapse.storage.databases.main.registration import (
     LoginTokenLookupResult,
     LoginTokenReused,
 )
-from synapse.types import JsonDict, Requester, UserID
+from synapse.types import JsonDict, Requester, StrCollection, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
 from synapse.util.msisdn import phone_number_to_msisdn
@@ -1547,6 +1547,31 @@ class AuthHandler:
             user_id, (token_id for _, token_id, _ in tokens_and_devices)
         )
 
+    async def delete_access_tokens_for_devices(
+        self,
+        user_id: str,
+        device_ids: StrCollection,
+    ) -> None:
+        """Invalidate access tokens for the given devices.
+
+        Args:
+            user_id: ID of the user the tokens belong to
+            device_ids: IDs of the devices the tokens are associated with
+        """
+        tokens_and_devices = await self.store.user_delete_access_tokens_for_devices(
+            user_id,
+            device_ids,
+        )
+
+        # see if any modules want to know about this
+        if self.password_auth_provider.on_logged_out_callbacks:
+            for token, _, device_id in tokens_and_devices:
+                await self.password_auth_provider.on_logged_out(
+                    user_id=user_id, device_id=device_id, access_token=token
+                )
+
     async def add_threepid(
         self, user_id: str, medium: str, address: str, validated_at: int
     ) -> None:
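The point of the bulk variant is round trips: the old path called `delete_access_tokens_for_user` once per device, paying one database interaction each time, while the new path deletes a whole batch of device IDs per statement. A self-contained sketch of the difference using plain sqlite3 (the table layout and batching helper here are illustrative, not Synapse's schema or API):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE access_tokens (token TEXT, user_id TEXT, device_id TEXT)")

def seed() -> None:
    conn.executemany(
        "INSERT INTO access_tokens VALUES (?, ?, ?)",
        [(f"tok{i}", "@bot:example.com", f"device{i}") for i in range(10_000)],
    )

device_ids = [f"device{i}" for i in range(10_000)]

# Old shape: one DELETE per device, i.e. 10,000 statements.
seed()
for device_id in device_ids:
    conn.execute(
        "DELETE FROM access_tokens WHERE user_id = ? AND device_id = ?",
        ("@bot:example.com", device_id),
    )

# New shape: one IN-list DELETE per batch of 1000, i.e. 10 statements.
seed()
for start in range(0, len(device_ids), 1000):
    batch = device_ids[start : start + 1000]
    placeholders = ", ".join("?" * len(batch))
    conn.execute(
        f"DELETE FROM access_tokens"
        f" WHERE user_id = ? AND device_id IN ({placeholders})",
        ["@bot:example.com", *batch],
    )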

synapse/handlers/device.py

@@ -671,12 +671,12 @@ class DeviceHandler(DeviceWorkerHandler):
             except_device_id: optional device id which should not be deleted
         """
         device_map = await self.store.get_devices_by_user(user_id)
-        device_ids = list(device_map)
         if except_device_id is not None:
-            device_ids = [d for d in device_ids if d != except_device_id]
-        await self.delete_devices(user_id, device_ids)
+            device_map.pop(except_device_id, None)
+        user_device_ids = device_map.keys()
+        await self.delete_devices(user_id, user_device_ids)
 
-    async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
+    async def delete_devices(self, user_id: str, device_ids: StrCollection) -> None:
         """Delete several devices
 
         Args:
@@ -695,17 +695,10 @@ class DeviceHandler(DeviceWorkerHandler):
             else:
                 raise
 
-        # Delete data specific to each device. Not optimised as it is not
-        # considered as part of a critical path.
-        for device_id in device_ids:
-            await self._auth_handler.delete_access_tokens_for_user(
-                user_id, device_id=device_id
-            )
-            await self.store.delete_e2e_keys_by_device(
-                user_id=user_id, device_id=device_id
-            )
-
-            if self.hs.config.experimental.msc3890_enabled:
+        # Delete data specific to each device. Not optimised as it's an
+        # experimental MSC.
+        if self.hs.config.experimental.msc3890_enabled:
+            for device_id in device_ids:
                 # Remove any local notification settings for this device in accordance
                 # with MSC3890.
                 await self._account_data_handler.remove_account_data_for_user(
@@ -713,6 +706,13 @@ class DeviceHandler(DeviceWorkerHandler):
                     f"org.matrix.msc3890.local_notification_settings.{device_id}",
                 )
 
+        # If we're deleting a lot of devices, a bunch of them may not have any
+        # to-device messages queued up. We filter those out to avoid scheduling
+        # unnecessary tasks.
+        devices_with_messages = await self.store.get_devices_with_messages(
+            user_id, device_ids
+        )
+        for device_id in devices_with_messages:
             # Delete device messages asynchronously and in batches using the task scheduler
             # We specify an upper stream id to avoid deleting non delivered messages
             # if an user re-uses a device ID.
@@ -726,6 +726,10 @@ class DeviceHandler(DeviceWorkerHandler):
                 },
             )
 
+        await self._auth_handler.delete_access_tokens_for_devices(
+            user_id, device_ids=device_ids
+        )
+
         # Pushers are deleted after `delete_access_tokens_for_user` is called so that
         # modules using `on_logged_out` hook can use them if needed.
         await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
@@ -819,10 +823,11 @@ class DeviceHandler(DeviceWorkerHandler):
             # This should only happen if there are no updates, so we bail.
             return
 
-        for device_id in device_ids:
-            logger.debug(
-                "Notifying about update %r/%r, ID: %r", user_id, device_id, position
-            )
+        if logger.isEnabledFor(logging.DEBUG):
+            for device_id in device_ids:
+                logger.debug(
+                    "Notifying about update %r/%r, ID: %r", user_id, device_id, position
+                )
 
         # specify the user ID too since the user should always get their own device list
         # updates, even if they aren't in any rooms.
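Wrapping the loop in `logger.isEnabledFor(logging.DEBUG)` means a bulk deletion pays one cheap check instead of a per-device `logger.debug` call when debug logging is off. A small standalone sketch of the effect (names and sizes are illustrative):

import logging
import timeit

logger = logging.getLogger("demo")
logging.basicConfig(level=logging.INFO)  # DEBUG is disabled

device_ids = [f"device{i}" for i in range(100_000)]

def unguarded() -> None:
    # Pays the cost of a no-op logger.debug() call per device.
    for device_id in device_ids:
        logger.debug("Notifying about update %r", device_id)

def guarded() -> None:
    # One cheap level check skips the whole loop when DEBUG is off.
    if logger.isEnabledFor(logging.DEBUG):
        for device_id in device_ids:
            logger.debug("Notifying about update %r", device_id)

print("unguarded:", timeit.timeit(unguarded, number=10))
print("guarded:  ", timeit.timeit(guarded, number=10))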
@@ -922,9 +927,6 @@ class DeviceHandler(DeviceWorkerHandler):
             # can't call self.delete_device because that will clobber the
             # access token so call the storage layer directly
             await self.store.delete_devices(user_id, [old_device_id])
-            await self.store.delete_e2e_keys_by_device(
-                user_id=user_id, device_id=old_device_id
-            )
 
             # tell everyone that the old device is gone and that the dehydrated
             # device has a new display name
@@ -946,7 +948,6 @@ class DeviceHandler(DeviceWorkerHandler):
             raise errors.NotFoundError()
 
         await self.delete_devices(user_id, [device_id])
-        await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
 
     @wrap_as_background_process("_handle_new_device_update_async")
     async def _handle_new_device_update_async(self) -> None:
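The `devices_with_messages` filter above exists because scheduling a deletion task has a fixed cost even when a device's inbox is empty. A minimal sketch of the idea (get_devices_with_messages and schedule_task here are toy stand-ins, not Synapse's real interfaces):

import asyncio
from typing import Dict, Iterable, Set

# Pretend device inbox: device ID -> number of queued to-device messages.
INBOX: Dict[str, int] = {"bot-1": 0, "bot-2": 3, "bot-3": 0, "laptop": 1}

async def get_devices_with_messages(device_ids: Iterable[str]) -> Set[str]:
    # Stand-in for the single SELECT DISTINCT query added in this commit.
    return {d for d in device_ids if INBOX.get(d, 0) > 0}

async def schedule_task(name: str, **params: str) -> None:
    # Stand-in for Synapse's task scheduler; just records the call here.
    print(f"scheduled {name} with {params}")

async def delete_devices(device_ids: Iterable[str]) -> None:
    # Only devices that actually have queued messages get a cleanup task, so
    # deleting thousands of idle bot devices schedules almost no tasks.
    for device_id in await get_devices_with_messages(device_ids):
        await schedule_task("delete_device_messages", device_id=device_id)

asyncio.run(delete_devices(["bot-1", "bot-2", "bot-3", "laptop"]))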

synapse/storage/databases/main/deviceinbox.py

@@ -52,10 +52,11 @@ from synapse.storage.database import (
     make_in_list_sql_clause,
 )
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import JsonDict
+from synapse.types import JsonDict, StrCollection
 from synapse.util import Duration, json_encoder
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
@@ -1027,6 +1028,40 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 # loop first time we run this.
                 self._clock.sleep(1)
 
+    async def get_devices_with_messages(
+        self, user_id: str, device_ids: StrCollection
+    ) -> StrCollection:
+        """Get the matching device IDs that have messages in the device inbox."""
+
+        def get_devices_with_messages_txn(
+            txn: LoggingTransaction,
+            batch_device_ids: StrCollection,
+        ) -> StrCollection:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "device_id", batch_device_ids
+            )
+            sql = f"""
+                SELECT DISTINCT device_id FROM device_inbox
+                WHERE {clause} AND user_id = ?
+            """
+            args.append(user_id)
+            txn.execute(sql, args)
+            return {row[0] for row in txn}
+
+        results: Set[str] = set()
+        for batch_device_ids in batch_iter(device_ids, 1000):
+            batch_results = await self.db_pool.runInteraction(
+                "get_devices_with_messages",
+                get_devices_with_messages_txn,
+                batch_device_ids,
+                # We don't need to run in a transaction as it's a single query
+                db_autocommit=True,
+            )
+            results.update(batch_results)
+
+        return results
+
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
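Batching the IN-list at 1000 keeps each statement's parameter count bounded; SQLite in particular caps bound variables (999 by default in older builds, 32766 since 3.32), and very large IN lists degrade on Postgres as well. A standalone sketch of the same query shape, with a simplified stand-in for Synapse's make_in_list_sql_clause helper:

import sqlite3
from typing import List, Sequence, Set, Tuple

def make_in_list_clause(column: str, values: Sequence[str]) -> Tuple[str, List[str]]:
    # Simplified stand-in for Synapse's make_in_list_sql_clause:
    # renders "col IN (?, ?, ...)" plus the matching argument list.
    return f"{column} IN ({', '.join('?' * len(values))})", list(values)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE device_inbox (user_id TEXT, device_id TEXT)")
conn.executemany(
    "INSERT INTO device_inbox VALUES (?, ?)",
    [("@bot:example.com", f"device{i}") for i in range(0, 5000, 7)],
)

device_ids = [f"device{i}" for i in range(5000)]
results: Set[str] = set()
for start in range(0, len(device_ids), 1000):
    clause, args = make_in_list_clause("device_id", device_ids[start : start + 1000])
    rows = conn.execute(
        f"SELECT DISTINCT device_id FROM device_inbox WHERE {clause} AND user_id = ?",
        args + ["@bot:example.com"],
    )
    results.update(row[0] for row in rows)

print(len(results))  # 715 devices have queued messages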

synapse/storage/databases/main/devices.py

@@ -282,7 +282,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             "count_devices_by_users", count_devices_by_users_txn, user_ids
         )
 
-    @cached()
+    @cached(tree=True)
     async def get_device(
         self, user_id: str, device_id: str
     ) -> Optional[Mapping[str, Any]]:
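`tree=True` stores cache entries in a tree keyed on the argument tuple, so every entry under a `(user_id,)` prefix can be invalidated in one call; that is what lets `delete_devices` below drop all of a user's `(user_id, device_id)` entries at once instead of one per device. A toy sketch of prefix invalidation (not Synapse's cache implementation):

from typing import Dict, Tuple

class TreeCache:
    """Toy cache keyed on tuples, supporting invalidation by key prefix."""

    def __init__(self) -> None:
        self._entries: Dict[Tuple[str, ...], object] = {}

    def set(self, key: Tuple[str, ...], value: object) -> None:
        self._entries[key] = value

    def invalidate_prefix(self, prefix: Tuple[str, ...]) -> None:
        # A real tree cache walks a nested dict; a scan shows the semantics.
        for key in [k for k in self._entries if k[: len(prefix)] == prefix]:
            del self._entries[key]

cache = TreeCache()
cache.set(("@bot:example.com", "device1"), {"display_name": "bot 1"})
cache.set(("@bot:example.com", "device2"), {"display_name": "bot 2"})
cache.set(("@alice:example.com", "phone"), {"display_name": "phone"})

# One invalidation covers every device belonging to the user.
cache.invalidate_prefix(("@bot:example.com",))
print(list(cache._entries))  # [('@alice:example.com', 'phone')]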
@@ -1861,7 +1861,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             )
             raise StoreError(500, "Problem storing device.")
 
-    async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
+    async def delete_devices(self, user_id: str, device_ids: StrCollection) -> None:
         """Deletes several devices.
 
         Args:
@@ -1885,11 +1885,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 values=device_ids,
                 keyvalues={"user_id": user_id},
             )
-            self._invalidate_cache_and_stream_bulk(
-                txn, self.get_device, [(user_id, device_id) for device_id in device_ids]
-            )
+
+            # Also delete associated e2e keys.
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="e2e_device_keys_json",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=device_ids,
+            )
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="e2e_one_time_keys_json",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=device_ids,
+            )
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="dehydrated_devices",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=device_ids,
+            )
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="e2e_fallback_keys_json",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=device_ids,
+            )
+
+            # We're bulk deleting potentially many devices at once, so
+            # let's not invalidate the cache for each device individually.
+            # Instead, we will invalidate the cache for the user as a whole.
+            self._invalidate_cache_and_stream(txn, self.get_device, (user_id,))
+            self._invalidate_cache_and_stream(
+                txn, self.count_e2e_one_time_keys, (user_id,)
+            )
+            self._invalidate_cache_and_stream(
+                txn, self.get_e2e_unused_fallback_key_types, (user_id,)
+            )
 
-        for batch in batch_iter(device_ids, 100):
+        for batch in batch_iter(device_ids, 1000):
             await self.db_pool.runInteraction(
                 "delete_devices", _delete_devices_txn, batch
             )
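`batch_iter` is what bounds each delete_devices transaction, now at 1000 devices rather than 100. A minimal reimplementation with the behaviour assumed here for synapse.util.iterutils.batch_iter, chunking any iterable into fixed-size tuples:

from itertools import islice
from typing import Iterable, Iterator, Tuple, TypeVar

T = TypeVar("T")

def batch_iter(iterable: Iterable[T], size: int) -> Iterator[Tuple[T, ...]]:
    # Repeatedly pull up to `size` items; the empty tuple sentinel ends it.
    it = iter(iterable)
    return iter(lambda: tuple(islice(it, size)), ())

device_ids = [f"device{i}" for i in range(2500)]
for batch in batch_iter(device_ids, 1000):
    print(len(batch))  # 1000, 1000, 500 -- each batch becomes one transaction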
@@ -2061,32 +2099,36 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         context = get_active_span_text_map()
 
         def add_device_changes_txn(
-            txn: LoggingTransaction, stream_ids: List[int]
+            txn: LoggingTransaction,
+            batch_device_ids: StrCollection,
+            stream_ids: List[int],
         ) -> None:
             self._add_device_change_to_stream_txn(
                 txn,
                 user_id,
-                device_ids,
+                batch_device_ids,
                 stream_ids,
             )
 
             self._add_device_outbound_room_poke_txn(
                 txn,
                 user_id,
-                device_ids,
+                batch_device_ids,
                 room_ids,
                 stream_ids,
                 context,
             )
 
-        async with self._device_list_id_gen.get_next_mult(
-            len(device_ids)
-        ) as stream_ids:
-            await self.db_pool.runInteraction(
-                "add_device_change_to_stream",
-                add_device_changes_txn,
-                stream_ids,
-            )
+        for batch_device_ids in batch_iter(device_ids, 1000):
+            async with self._device_list_id_gen.get_next_mult(
+                len(device_ids)
+            ) as stream_ids:
+                await self.db_pool.runInteraction(
+                    "add_device_change_to_stream",
+                    add_device_changes_txn,
+                    batch_device_ids,
+                    stream_ids,
+                )
 
         return stream_ids[-1]

synapse/storage/databases/main/end_to_end_keys.py

@@ -593,7 +593,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             txn, self.count_e2e_one_time_keys, (user_id, device_id)
         )
 
-    @cached(max_entries=10000)
+    @cached(max_entries=10000, tree=True)
     async def count_e2e_one_time_keys(
         self, user_id: str, device_id: str
     ) -> Mapping[str, int]:
@@ -808,7 +808,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
                 },
             )
 
-    @cached(max_entries=10000)
+    @cached(max_entries=10000, tree=True)
     async def get_e2e_unused_fallback_key_types(
         self, user_id: str, device_id: str
     ) -> Sequence[str]:
@@ -1632,46 +1632,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             log_kv({"message": "Device keys stored."})
             return True
 
-    async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
-        def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
-            log_kv(
-                {
-                    "message": "Deleting keys for device",
-                    "device_id": device_id,
-                    "user_id": user_id,
-                }
-            )
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="e2e_one_time_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="dehydrated_devices",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self.db_pool.simple_delete_txn(
-                txn,
-                table="e2e_fallback_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
-            )
-
-        await self.db_pool.runInteraction(
-            "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
-        )
-
     def _set_e2e_cross_signing_key_txn(
         self,
         txn: LoggingTransaction,

synapse/storage/databases/main/registration.py

@@ -40,14 +40,16 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import JsonDict, UserID, UserInfo
+from synapse.types import JsonDict, StrCollection, UserID, UserInfo
 from synapse.util.caches.descriptors import cached
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -2801,6 +2803,81 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         return await self.db_pool.runInteraction("user_delete_access_tokens", f)
 
+    async def user_delete_access_tokens_for_devices(
+        self,
+        user_id: str,
+        device_ids: StrCollection,
+    ) -> List[Tuple[str, int, Optional[str]]]:
+        """
+        Invalidate access and refresh tokens belonging to a user's devices.
+
+        Args:
+            user_id: ID of the user the tokens belong to
+            device_ids: The devices to delete tokens for.
+
+        Returns:
+            A list of (token, token id, device id) tuples, one for each
+            deleted access token
+        """
+
+        def user_delete_access_tokens_for_devices_txn(
+            txn: LoggingTransaction, batch_device_ids: StrCollection
+        ) -> List[Tuple[str, int, Optional[str]]]:
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="refresh_tokens",
+                keyvalues={"user_id": user_id},
+                column="device_id",
+                values=batch_device_ids,
+            )
+
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "device_id", batch_device_ids
+            )
+            args.append(user_id)
+
+            if self.database_engine.supports_returning:
+                sql = f"""
+                    DELETE FROM access_tokens
+                    WHERE {clause} AND user_id = ?
+                    RETURNING token, id, device_id
+                """
+                txn.execute(sql, args)
+                tokens_and_devices = txn.fetchall()
+            else:
+                tokens_and_devices = self.db_pool.simple_select_many_txn(
+                    txn,
+                    table="access_tokens",
+                    column="device_id",
+                    iterable=batch_device_ids,
+                    keyvalues={"user_id": user_id},
+                    retcols=("token", "id", "device_id"),
+                )
+
+                self.db_pool.simple_delete_many_txn(
+                    txn,
+                    table="access_tokens",
+                    keyvalues={"user_id": user_id},
+                    column="device_id",
+                    values=batch_device_ids,
+                )
+
+            self._invalidate_cache_and_stream_bulk(
+                txn,
+                self.get_user_by_access_token,
+                [(t[0],) for t in tokens_and_devices],
+            )
+
+            return tokens_and_devices
+
+        results = []
+        for batch_device_ids in batch_iter(device_ids, 1000):
+            tokens_and_devices = await self.db_pool.runInteraction(
+                "user_delete_access_tokens_for_devices",
+                user_delete_access_tokens_for_devices_txn,
+                batch_device_ids,
+            )
+            results.extend(tokens_and_devices)
+
+        return results
+
     async def delete_access_token(self, access_token: str) -> None:
         def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_one_txn(
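The `supports_returning` branch exists because Postgres can delete rows and hand them back in a single statement, while SQLite only gained DELETE ... RETURNING in 3.35, so older engines need a select-then-delete fallback. A standalone sqlite3 sketch of both paths (the table layout is illustrative, not Synapse's schema):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE access_tokens (token TEXT, id INTEGER, device_id TEXT, user_id TEXT)"
)
conn.executemany(
    "INSERT INTO access_tokens VALUES (?, ?, ?, ?)",
    [(f"tok{i}", i, f"device{i}", "@bot:example.com") for i in range(3)],
)

supports_returning = sqlite3.sqlite_version_info >= (3, 35, 0)

if supports_returning:
    # One statement: delete and fetch the deleted rows together.
    rows = conn.execute(
        "DELETE FROM access_tokens WHERE user_id = ? RETURNING token, id, device_id",
        ("@bot:example.com",),
    ).fetchall()
else:
    # Two statements: select what will be deleted, then delete it.
    rows = conn.execute(
        "SELECT token, id, device_id FROM access_tokens WHERE user_id = ?",
        ("@bot:example.com",),
    ).fetchall()
    conn.execute("DELETE FROM access_tokens WHERE user_id = ?", ("@bot:example.com",))

print(rows)  # [('tok0', 0, 'device0'), ('tok1', 1, 'device1'), ('tok2', 2, 'device2')]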