Speed up device deletion (#18602)
This is to handle the case of deleting lots of "bot" devices at once. Reviewable commit-by-commit --------- Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com>
This commit is contained in:
1
changelog.d/18602.misc
Normal file
1
changelog.d/18602.misc
Normal file
@@ -0,0 +1 @@
|
||||
Speed up bulk device deletion.
|
||||
@@ -76,7 +76,7 @@ from synapse.storage.databases.main.registration import (
|
||||
LoginTokenLookupResult,
|
||||
LoginTokenReused,
|
||||
)
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.types import JsonDict, Requester, StrCollection, UserID
|
||||
from synapse.util import stringutils as stringutils
|
||||
from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
@@ -1547,6 +1547,31 @@ class AuthHandler:
|
||||
user_id, (token_id for _, token_id, _ in tokens_and_devices)
|
||||
)
|
||||
|
||||
async def delete_access_tokens_for_devices(
|
||||
self,
|
||||
user_id: str,
|
||||
device_ids: StrCollection,
|
||||
) -> None:
|
||||
"""Invalidate access tokens for the devices
|
||||
|
||||
Args:
|
||||
user_id: ID of user the tokens belong to
|
||||
device_ids: ID of device the tokens are associated with.
|
||||
If None, tokens associated with any device (or no device) will
|
||||
be deleted
|
||||
"""
|
||||
tokens_and_devices = await self.store.user_delete_access_tokens_for_devices(
|
||||
user_id,
|
||||
device_ids,
|
||||
)
|
||||
|
||||
# see if any modules want to know about this
|
||||
if self.password_auth_provider.on_logged_out_callbacks:
|
||||
for token, _, device_id in tokens_and_devices:
|
||||
await self.password_auth_provider.on_logged_out(
|
||||
user_id=user_id, device_id=device_id, access_token=token
|
||||
)
|
||||
|
||||
async def add_threepid(
|
||||
self, user_id: str, medium: str, address: str, validated_at: int
|
||||
) -> None:
|
||||
|
||||
@@ -671,12 +671,12 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
except_device_id: optional device id which should not be deleted
|
||||
"""
|
||||
device_map = await self.store.get_devices_by_user(user_id)
|
||||
device_ids = list(device_map)
|
||||
if except_device_id is not None:
|
||||
device_ids = [d for d in device_ids if d != except_device_id]
|
||||
await self.delete_devices(user_id, device_ids)
|
||||
device_map.pop(except_device_id, None)
|
||||
user_device_ids = device_map.keys()
|
||||
await self.delete_devices(user_id, user_device_ids)
|
||||
|
||||
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
|
||||
async def delete_devices(self, user_id: str, device_ids: StrCollection) -> None:
|
||||
"""Delete several devices
|
||||
|
||||
Args:
|
||||
@@ -695,17 +695,10 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
else:
|
||||
raise
|
||||
|
||||
# Delete data specific to each device. Not optimised as it is not
|
||||
# considered as part of a critical path.
|
||||
for device_id in device_ids:
|
||||
await self._auth_handler.delete_access_tokens_for_user(
|
||||
user_id, device_id=device_id
|
||||
)
|
||||
await self.store.delete_e2e_keys_by_device(
|
||||
user_id=user_id, device_id=device_id
|
||||
)
|
||||
|
||||
if self.hs.config.experimental.msc3890_enabled:
|
||||
# Delete data specific to each device. Not optimised as its an
|
||||
# experimental MSC.
|
||||
if self.hs.config.experimental.msc3890_enabled:
|
||||
for device_id in device_ids:
|
||||
# Remove any local notification settings for this device in accordance
|
||||
# with MSC3890.
|
||||
await self._account_data_handler.remove_account_data_for_user(
|
||||
@@ -713,6 +706,13 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
f"org.matrix.msc3890.local_notification_settings.{device_id}",
|
||||
)
|
||||
|
||||
# If we're deleting a lot of devices, a bunch of them may not have any
|
||||
# to-device messages queued up. We filter those out to avoid scheduling
|
||||
# unnecessary tasks.
|
||||
devices_with_messages = await self.store.get_devices_with_messages(
|
||||
user_id, device_ids
|
||||
)
|
||||
for device_id in devices_with_messages:
|
||||
# Delete device messages asynchronously and in batches using the task scheduler
|
||||
# We specify an upper stream id to avoid deleting non delivered messages
|
||||
# if an user re-uses a device ID.
|
||||
@@ -726,6 +726,10 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
},
|
||||
)
|
||||
|
||||
await self._auth_handler.delete_access_tokens_for_devices(
|
||||
user_id, device_ids=device_ids
|
||||
)
|
||||
|
||||
# Pushers are deleted after `delete_access_tokens_for_user` is called so that
|
||||
# modules using `on_logged_out` hook can use them if needed.
|
||||
await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)
|
||||
@@ -819,10 +823,11 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
# This should only happen if there are no updates, so we bail.
|
||||
return
|
||||
|
||||
for device_id in device_ids:
|
||||
logger.debug(
|
||||
"Notifying about update %r/%r, ID: %r", user_id, device_id, position
|
||||
)
|
||||
if logger.isEnabledFor(logging.DEBUG):
|
||||
for device_id in device_ids:
|
||||
logger.debug(
|
||||
"Notifying about update %r/%r, ID: %r", user_id, device_id, position
|
||||
)
|
||||
|
||||
# specify the user ID too since the user should always get their own device list
|
||||
# updates, even if they aren't in any rooms.
|
||||
@@ -922,9 +927,6 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
# can't call self.delete_device because that will clobber the
|
||||
# access token so call the storage layer directly
|
||||
await self.store.delete_devices(user_id, [old_device_id])
|
||||
await self.store.delete_e2e_keys_by_device(
|
||||
user_id=user_id, device_id=old_device_id
|
||||
)
|
||||
|
||||
# tell everyone that the old device is gone and that the dehydrated
|
||||
# device has a new display name
|
||||
@@ -946,7 +948,6 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
raise errors.NotFoundError()
|
||||
|
||||
await self.delete_devices(user_id, [device_id])
|
||||
await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
|
||||
|
||||
@wrap_as_background_process("_handle_new_device_update_async")
|
||||
async def _handle_new_device_update_async(self) -> None:
|
||||
|
||||
@@ -52,10 +52,11 @@ from synapse.storage.database import (
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, StrCollection
|
||||
from synapse.util import Duration, json_encoder
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.stringutils import parse_and_validate_server_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -1027,6 +1028,40 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
# loop first time we run this.
|
||||
self._clock.sleep(1)
|
||||
|
||||
async def get_devices_with_messages(
|
||||
self, user_id: str, device_ids: StrCollection
|
||||
) -> StrCollection:
|
||||
"""Get the matching device IDs that have messages in the device inbox."""
|
||||
|
||||
def get_devices_with_messages_txn(
|
||||
txn: LoggingTransaction,
|
||||
batch_device_ids: StrCollection,
|
||||
) -> StrCollection:
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "device_id", batch_device_ids
|
||||
)
|
||||
sql = f"""
|
||||
SELECT DISTINCT device_id FROM device_inbox
|
||||
WHERE {clause} AND user_id = ?
|
||||
"""
|
||||
args.append(user_id)
|
||||
txn.execute(sql, args)
|
||||
return {row[0] for row in txn}
|
||||
|
||||
results: Set[str] = set()
|
||||
for batch_device_ids in batch_iter(device_ids, 1000):
|
||||
batch_results = await self.db_pool.runInteraction(
|
||||
"get_devices_with_messages",
|
||||
get_devices_with_messages_txn,
|
||||
batch_device_ids,
|
||||
# We don't need to run in a transaction as it's a single query
|
||||
db_autocommit=True,
|
||||
)
|
||||
|
||||
results.update(batch_results)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
|
||||
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
|
||||
|
||||
@@ -282,7 +282,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
|
||||
"count_devices_by_users", count_devices_by_users_txn, user_ids
|
||||
)
|
||||
|
||||
@cached()
|
||||
@cached(tree=True)
|
||||
async def get_device(
|
||||
self, user_id: str, device_id: str
|
||||
) -> Optional[Mapping[str, Any]]:
|
||||
@@ -1861,7 +1861,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
)
|
||||
raise StoreError(500, "Problem storing device.")
|
||||
|
||||
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
|
||||
async def delete_devices(self, user_id: str, device_ids: StrCollection) -> None:
|
||||
"""Deletes several devices.
|
||||
|
||||
Args:
|
||||
@@ -1885,11 +1885,49 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
values=device_ids,
|
||||
keyvalues={"user_id": user_id},
|
||||
)
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn, self.get_device, [(user_id, device_id) for device_id in device_ids]
|
||||
|
||||
# Also delete associated e2e keys.
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={"user_id": user_id},
|
||||
column="device_id",
|
||||
values=device_ids,
|
||||
)
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="e2e_one_time_keys_json",
|
||||
keyvalues={"user_id": user_id},
|
||||
column="device_id",
|
||||
values=device_ids,
|
||||
)
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="dehydrated_devices",
|
||||
keyvalues={"user_id": user_id},
|
||||
column="device_id",
|
||||
values=device_ids,
|
||||
)
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={"user_id": user_id},
|
||||
column="device_id",
|
||||
values=device_ids,
|
||||
)
|
||||
|
||||
for batch in batch_iter(device_ids, 100):
|
||||
# We're bulk deleting potentially many devices at once, so
|
||||
# let's not invalidate the cache for each device individually.
|
||||
# Instead, we will invalidate the cache for the user as a whole.
|
||||
self._invalidate_cache_and_stream(txn, self.get_device, (user_id,))
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id,)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_e2e_unused_fallback_key_types, (user_id,)
|
||||
)
|
||||
|
||||
for batch in batch_iter(device_ids, 1000):
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_devices", _delete_devices_txn, batch
|
||||
)
|
||||
@@ -2061,32 +2099,36 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
context = get_active_span_text_map()
|
||||
|
||||
def add_device_changes_txn(
|
||||
txn: LoggingTransaction, stream_ids: List[int]
|
||||
txn: LoggingTransaction,
|
||||
batch_device_ids: StrCollection,
|
||||
stream_ids: List[int],
|
||||
) -> None:
|
||||
self._add_device_change_to_stream_txn(
|
||||
txn,
|
||||
user_id,
|
||||
device_ids,
|
||||
batch_device_ids,
|
||||
stream_ids,
|
||||
)
|
||||
|
||||
self._add_device_outbound_room_poke_txn(
|
||||
txn,
|
||||
user_id,
|
||||
device_ids,
|
||||
batch_device_ids,
|
||||
room_ids,
|
||||
stream_ids,
|
||||
context,
|
||||
)
|
||||
|
||||
async with self._device_list_id_gen.get_next_mult(
|
||||
len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_change_to_stream",
|
||||
add_device_changes_txn,
|
||||
stream_ids,
|
||||
)
|
||||
for batch_device_ids in batch_iter(device_ids, 1000):
|
||||
async with self._device_list_id_gen.get_next_mult(
|
||||
len(device_ids)
|
||||
) as stream_ids:
|
||||
await self.db_pool.runInteraction(
|
||||
"add_device_change_to_stream",
|
||||
add_device_changes_txn,
|
||||
batch_device_ids,
|
||||
stream_ids,
|
||||
)
|
||||
|
||||
return stream_ids[-1]
|
||||
|
||||
|
||||
@@ -593,7 +593,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
@cached(max_entries=10000, tree=True)
|
||||
async def count_e2e_one_time_keys(
|
||||
self, user_id: str, device_id: str
|
||||
) -> Mapping[str, int]:
|
||||
@@ -808,7 +808,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
|
||||
},
|
||||
)
|
||||
|
||||
@cached(max_entries=10000)
|
||||
@cached(max_entries=10000, tree=True)
|
||||
async def get_e2e_unused_fallback_key_types(
|
||||
self, user_id: str, device_id: str
|
||||
) -> Sequence[str]:
|
||||
@@ -1632,46 +1632,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
log_kv({"message": "Device keys stored."})
|
||||
return True
|
||||
|
||||
async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
|
||||
def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
|
||||
log_kv(
|
||||
{
|
||||
"message": "Deleting keys for device",
|
||||
"device_id": device_id,
|
||||
"user_id": user_id,
|
||||
}
|
||||
)
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="e2e_device_keys_json",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
)
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="e2e_one_time_keys_json",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="dehydrated_devices",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
)
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
table="e2e_fallback_keys_json",
|
||||
keyvalues={"user_id": user_id, "device_id": device_id},
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_e2e_unused_fallback_key_types, (user_id, device_id)
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn
|
||||
)
|
||||
|
||||
def _set_e2e_cross_signing_key_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
|
||||
@@ -40,14 +40,16 @@ from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.stats import StatsStore
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.storage.util.id_generators import IdGenerator
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import JsonDict, UserID, UserInfo
|
||||
from synapse.types import JsonDict, StrCollection, UserID, UserInfo
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -2801,6 +2803,81 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
|
||||
|
||||
return await self.db_pool.runInteraction("user_delete_access_tokens", f)
|
||||
|
||||
async def user_delete_access_tokens_for_devices(
|
||||
self,
|
||||
user_id: str,
|
||||
device_ids: StrCollection,
|
||||
) -> List[Tuple[str, int, Optional[str]]]:
|
||||
"""
|
||||
Invalidate access and refresh tokens belonging to a user
|
||||
|
||||
Args:
|
||||
user_id: ID of user the tokens belong to
|
||||
device_ids: The devices to delete tokens for.
|
||||
Returns:
|
||||
A tuple of (token, token id, device id) for each of the deleted tokens
|
||||
"""
|
||||
|
||||
def user_delete_access_tokens_for_devices_txn(
|
||||
txn: LoggingTransaction, batch_device_ids: StrCollection
|
||||
) -> List[Tuple[str, int, Optional[str]]]:
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="refresh_tokens",
|
||||
keyvalues={"user_id": user_id},
|
||||
column="device_id",
|
||||
values=batch_device_ids,
|
||||
)
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "device_id", batch_device_ids
|
||||
)
|
||||
args.append(user_id)
|
||||
|
||||
if self.database_engine.supports_returning:
|
||||
sql = f"""
|
||||
DELETE FROM access_tokens
|
||||
WHERE {clause} AND user_id = ?
|
||||
RETURNING token, id, device_id
|
||||
"""
|
||||
txn.execute(sql, args)
|
||||
tokens_and_devices = txn.fetchall()
|
||||
else:
|
||||
tokens_and_devices = self.db_pool.simple_select_many_txn(
|
||||
txn,
|
||||
table="access_tokens",
|
||||
column="device_id",
|
||||
iterable=batch_device_ids,
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("token", "id", "device_id"),
|
||||
)
|
||||
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="access_tokens",
|
||||
keyvalues={"user_id": user_id},
|
||||
column="device_id",
|
||||
values=batch_device_ids,
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream_bulk(
|
||||
txn,
|
||||
self.get_user_by_access_token,
|
||||
[(t[0],) for t in tokens_and_devices],
|
||||
)
|
||||
return tokens_and_devices
|
||||
|
||||
results = []
|
||||
for batch_device_ids in batch_iter(device_ids, 1000):
|
||||
tokens_and_devices = await self.db_pool.runInteraction(
|
||||
"user_delete_access_tokens_for_devices",
|
||||
user_delete_access_tokens_for_devices_txn,
|
||||
batch_device_ids,
|
||||
)
|
||||
results.extend(tokens_and_devices)
|
||||
|
||||
return results
|
||||
|
||||
async def delete_access_token(self, access_token: str) -> None:
|
||||
def f(txn: LoggingTransaction) -> None:
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
|
||||
Reference in New Issue
Block a user