Merge commit '7f837959e' into anoa/dinsic_release_1_21_x
* commit '7f837959e': Convert directory, e2e_room_keys, end_to_end_keys, monthly_active_users database to async (#8042) Convert additional database stores to async/await (#8045)
This commit is contained in:
1
changelog.d/8042.misc
Normal file
1
changelog.d/8042.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8045.misc
Normal file
1
changelog.d/8045.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
@@ -14,8 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
@@ -82,21 +81,19 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
||||
"devices_last_seen", self._devices_last_seen_update
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _remove_user_ip_nonunique(self, progress, batch_size):
|
||||
async def _remove_user_ip_nonunique(self, progress, batch_size):
|
||||
def f(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
|
||||
txn.close()
|
||||
|
||||
yield self.db_pool.runWithConnection(f)
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.runWithConnection(f)
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"user_ips_drop_nonunique_index"
|
||||
)
|
||||
return 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _analyze_user_ip(self, progress, batch_size):
|
||||
async def _analyze_user_ip(self, progress, batch_size):
|
||||
# Background update to analyze user_ips table before we run the
|
||||
# deduplication background update. The table may not have been analyzed
|
||||
# for ages due to the table locks.
|
||||
@@ -106,14 +103,13 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
||||
def user_ips_analyze(txn):
|
||||
txn.execute("ANALYZE user_ips")
|
||||
|
||||
yield self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
|
||||
await self.db_pool.runInteraction("user_ips_analyze", user_ips_analyze)
|
||||
|
||||
yield self.db_pool.updates._end_background_update("user_ips_analyze")
|
||||
await self.db_pool.updates._end_background_update("user_ips_analyze")
|
||||
|
||||
return 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _remove_user_ip_dupes(self, progress, batch_size):
|
||||
async def _remove_user_ip_dupes(self, progress, batch_size):
|
||||
# This works function works by scanning the user_ips table in batches
|
||||
# based on `last_seen`. For each row in a batch it searches the rest of
|
||||
# the table to see if there are any duplicates, if there are then they
|
||||
@@ -140,7 +136,7 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
||||
return None
|
||||
|
||||
# Get a last seen that has roughly `batch_size` since `begin_last_seen`
|
||||
end_last_seen = yield self.db_pool.runInteraction(
|
||||
end_last_seen = await self.db_pool.runInteraction(
|
||||
"user_ips_dups_get_last_seen", get_last_seen
|
||||
)
|
||||
|
||||
@@ -275,15 +271,14 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
||||
txn, "user_ips_remove_dupes", {"last_seen": end_last_seen}
|
||||
)
|
||||
|
||||
yield self.db_pool.runInteraction("user_ips_dups_remove", remove)
|
||||
await self.db_pool.runInteraction("user_ips_dups_remove", remove)
|
||||
|
||||
if last:
|
||||
yield self.db_pool.updates._end_background_update("user_ips_remove_dupes")
|
||||
await self.db_pool.updates._end_background_update("user_ips_remove_dupes")
|
||||
|
||||
return batch_size
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _devices_last_seen_update(self, progress, batch_size):
|
||||
async def _devices_last_seen_update(self, progress, batch_size):
|
||||
"""Background update to insert last seen info into devices table
|
||||
"""
|
||||
|
||||
@@ -346,12 +341,12 @@ class ClientIpBackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
return len(rows)
|
||||
|
||||
updated = yield self.db_pool.runInteraction(
|
||||
updated = await self.db_pool.runInteraction(
|
||||
"_devices_last_seen_update", _devices_last_seen_update_txn
|
||||
)
|
||||
|
||||
if not updated:
|
||||
yield self.db_pool.updates._end_background_update("devices_last_seen")
|
||||
await self.db_pool.updates._end_background_update("devices_last_seen")
|
||||
|
||||
return updated
|
||||
|
||||
@@ -460,25 +455,25 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
||||
# Failed to upsert, log and continue
|
||||
logger.error("Failed to insert client IP %r: %r", entry, e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_last_client_ip_by_device(self, user_id, device_id):
|
||||
async def get_last_client_ip_by_device(
|
||||
self, user_id: str, device_id: Optional[str]
|
||||
) -> Dict[Tuple[str, str], dict]:
|
||||
"""For each device_id listed, give the user_ip it was last seen on
|
||||
|
||||
Args:
|
||||
user_id (str)
|
||||
device_id (str): If None fetches all devices for the user
|
||||
user_id: The user to fetch devices for.
|
||||
device_id: If None fetches all devices for the user
|
||||
|
||||
Returns:
|
||||
defer.Deferred: resolves to a dict, where the keys
|
||||
are (user_id, device_id) tuples. The values are also dicts, with
|
||||
keys giving the column names
|
||||
A dictionary mapping a tuple of (user_id, device_id) to dicts, with
|
||||
keys giving the column names from the devices table.
|
||||
"""
|
||||
|
||||
keyvalues = {"user_id": user_id}
|
||||
if device_id is not None:
|
||||
keyvalues["device_id"] = device_id
|
||||
|
||||
res = yield self.db_pool.simple_select_list(
|
||||
res = await self.db_pool.simple_select_list(
|
||||
table="devices",
|
||||
keyvalues=keyvalues,
|
||||
retcols=("user_id", "ip", "user_agent", "device_id", "last_seen"),
|
||||
@@ -500,8 +495,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
||||
}
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_ip_and_agents(self, user):
|
||||
async def get_user_ip_and_agents(self, user):
|
||||
user_id = user.to_string()
|
||||
results = {}
|
||||
|
||||
@@ -511,7 +505,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
||||
user_agent, _, last_seen = self._batch_row_update[key]
|
||||
results[(access_token, ip)] = (user_agent, last_seen)
|
||||
|
||||
rows = yield self.db_pool.simple_select_list(
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=["access_token", "ip", "user_agent", "last_seen"],
|
||||
|
||||
@@ -136,7 +136,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
master_key_by_user = {}
|
||||
self_signing_key_by_user = {}
|
||||
for user in users:
|
||||
cross_signing_key = yield self.get_e2e_cross_signing_key(user, "master")
|
||||
cross_signing_key = yield defer.ensureDeferred(
|
||||
self.get_e2e_cross_signing_key(user, "master")
|
||||
)
|
||||
if cross_signing_key:
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||
cross_signing_key
|
||||
@@ -149,8 +151,8 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
"device_id": verify_key.version,
|
||||
}
|
||||
|
||||
cross_signing_key = yield self.get_e2e_cross_signing_key(
|
||||
user, "self_signing"
|
||||
cross_signing_key = yield defer.ensureDeferred(
|
||||
self.get_e2e_cross_signing_key(user, "self_signing")
|
||||
)
|
||||
if cross_signing_key:
|
||||
key_id, verify_key = get_verify_key_from_cross_signing_key(
|
||||
@@ -246,7 +248,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
destination (str): The host the device updates are intended for
|
||||
from_stream_id (int): The minimum stream_id to filter updates by, exclusive
|
||||
query_map (Dict[(str, str): (int, str|None)]): Dictionary mapping
|
||||
user_id/device_id to update stream_id and the relevent json-encoded
|
||||
user_id/device_id to update stream_id and the relevant json-encoded
|
||||
opentracing context
|
||||
|
||||
Returns:
|
||||
@@ -599,7 +601,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
between the requested tokens due to the limit.
|
||||
|
||||
The token returned can be used in a subsequent call to this
|
||||
function to get further updatees.
|
||||
function to get further updates.
|
||||
|
||||
The updates are a list of 2-tuples of stream ID and the row data
|
||||
"""
|
||||
|
||||
@@ -14,30 +14,29 @@
|
||||
# limitations under the License.
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Optional
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.types import RoomAlias
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
RoomAliasMapping = namedtuple("RoomAliasMapping", ("room_id", "room_alias", "servers"))
|
||||
|
||||
|
||||
class DirectoryWorkerStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
def get_association_from_room_alias(self, room_alias):
|
||||
""" Get's the room_id and server list for a given room_alias
|
||||
async def get_association_from_room_alias(
|
||||
self, room_alias: RoomAlias
|
||||
) -> Optional[RoomAliasMapping]:
|
||||
"""Gets the room_id and server list for a given room_alias
|
||||
|
||||
Args:
|
||||
room_alias (RoomAlias)
|
||||
room_alias: The alias to translate to an ID.
|
||||
|
||||
Returns:
|
||||
Deferred: results in namedtuple with keys "room_id" and
|
||||
"servers" or None if no association can be found
|
||||
The room alias mapping or None if no association can be found.
|
||||
"""
|
||||
room_id = yield self.db_pool.simple_select_one_onecol(
|
||||
room_id = await self.db_pool.simple_select_one_onecol(
|
||||
"room_aliases",
|
||||
{"room_alias": room_alias.to_string()},
|
||||
"room_id",
|
||||
@@ -48,7 +47,7 @@ class DirectoryWorkerStore(SQLBaseStore):
|
||||
if not room_id:
|
||||
return None
|
||||
|
||||
servers = yield self.db_pool.simple_select_onecol(
|
||||
servers = await self.db_pool.simple_select_onecol(
|
||||
"room_alias_servers",
|
||||
{"room_alias": room_alias.to_string()},
|
||||
"server",
|
||||
@@ -79,18 +78,20 @@ class DirectoryWorkerStore(SQLBaseStore):
|
||||
|
||||
|
||||
class DirectoryStore(DirectoryWorkerStore):
|
||||
@defer.inlineCallbacks
|
||||
def create_room_alias_association(self, room_alias, room_id, servers, creator=None):
|
||||
async def create_room_alias_association(
|
||||
self,
|
||||
room_alias: RoomAlias,
|
||||
room_id: str,
|
||||
servers: Iterable[str],
|
||||
creator: Optional[str] = None,
|
||||
) -> None:
|
||||
""" Creates an association between a room alias and room_id/servers
|
||||
|
||||
Args:
|
||||
room_alias (RoomAlias)
|
||||
room_id (str)
|
||||
servers (list)
|
||||
creator (str): Optional user_id of creator.
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
room_alias: The alias to create.
|
||||
room_id: The target of the alias.
|
||||
servers: A list of servers through which it may be possible to join the room
|
||||
creator: Optional user_id of creator.
|
||||
"""
|
||||
|
||||
def alias_txn(txn):
|
||||
@@ -118,24 +119,22 @@ class DirectoryStore(DirectoryWorkerStore):
|
||||
)
|
||||
|
||||
try:
|
||||
ret = yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"create_room_alias_association", alias_txn
|
||||
)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
raise SynapseError(
|
||||
409, "Room alias %s already exists" % room_alias.to_string()
|
||||
)
|
||||
return ret
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_room_alias(self, room_alias):
|
||||
room_id = yield self.db_pool.runInteraction(
|
||||
async def delete_room_alias(self, room_alias: RoomAlias) -> str:
|
||||
room_id = await self.db_pool.runInteraction(
|
||||
"delete_room_alias", self._delete_room_alias_txn, room_alias
|
||||
)
|
||||
|
||||
return room_id
|
||||
|
||||
def _delete_room_alias_txn(self, txn, room_alias):
|
||||
def _delete_room_alias_txn(self, txn, room_alias: RoomAlias) -> str:
|
||||
txn.execute(
|
||||
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
|
||||
(room_alias.to_string(),),
|
||||
|
||||
@@ -14,8 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.logging.opentracing import log_kv, trace
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
@@ -23,8 +21,9 @@ from synapse.util import json_encoder
|
||||
|
||||
|
||||
class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
def update_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
|
||||
async def update_e2e_room_key(
|
||||
self, user_id, version, room_id, session_id, room_key
|
||||
):
|
||||
"""Replaces the encrypted E2E room key for a given session in a given backup
|
||||
|
||||
Args:
|
||||
@@ -37,7 +36,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
StoreError
|
||||
"""
|
||||
|
||||
yield self.db_pool.simple_update_one(
|
||||
await self.db_pool.simple_update_one(
|
||||
table="e2e_room_keys",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
@@ -54,8 +53,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
desc="update_e2e_room_key",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_e2e_room_keys(self, user_id, version, room_keys):
|
||||
async def add_e2e_room_keys(self, user_id, version, room_keys):
|
||||
"""Bulk add room keys to a given backup.
|
||||
|
||||
Args:
|
||||
@@ -88,13 +86,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
}
|
||||
)
|
||||
|
||||
yield self.db_pool.simple_insert_many(
|
||||
await self.db_pool.simple_insert_many(
|
||||
table="e2e_room_keys", values=values, desc="add_e2e_room_keys"
|
||||
)
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
async def get_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
||||
room, or a given session.
|
||||
|
||||
@@ -109,7 +106,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
the backup (or for the specified room)
|
||||
|
||||
Returns:
|
||||
A deferred list of dicts giving the session_data and message metadata for
|
||||
A list of dicts giving the session_data and message metadata for
|
||||
these room keys.
|
||||
"""
|
||||
|
||||
@@ -124,7 +121,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
if session_id:
|
||||
keyvalues["session_id"] = session_id
|
||||
|
||||
rows = yield self.db_pool.simple_select_list(
|
||||
rows = await self.db_pool.simple_select_list(
|
||||
table="e2e_room_keys",
|
||||
keyvalues=keyvalues,
|
||||
retcols=(
|
||||
@@ -242,8 +239,9 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def delete_e2e_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
async def delete_e2e_room_keys(
|
||||
self, user_id, version, room_id=None, session_id=None
|
||||
):
|
||||
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
|
||||
room or a given session.
|
||||
|
||||
@@ -258,7 +256,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
the backup (or for the specified room)
|
||||
|
||||
Returns:
|
||||
A deferred of the deletion transaction
|
||||
The deletion transaction
|
||||
"""
|
||||
|
||||
keyvalues = {"user_id": user_id, "version": int(version)}
|
||||
@@ -267,7 +265,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
if session_id:
|
||||
keyvalues["session_id"] = session_id
|
||||
|
||||
yield self.db_pool.simple_delete(
|
||||
await self.db_pool.simple_delete(
|
||||
table="e2e_room_keys", keyvalues=keyvalues, desc="delete_e2e_room_keys"
|
||||
)
|
||||
|
||||
|
||||
@@ -14,12 +14,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.enterprise.adbapi import Connection
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
@@ -31,8 +30,7 @@ from synapse.util.iterutils import batch_iter
|
||||
|
||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
@trace
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_device_keys(
|
||||
async def get_e2e_device_keys(
|
||||
self, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
):
|
||||
"""Fetch a list of device keys.
|
||||
@@ -52,7 +50,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
if not query_list:
|
||||
return {}
|
||||
|
||||
results = yield self.db_pool.runInteraction(
|
||||
results = await self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys",
|
||||
self._get_e2e_device_keys_txn,
|
||||
query_list,
|
||||
@@ -175,8 +173,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
log_kv(result)
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
|
||||
async def get_e2e_one_time_keys(
|
||||
self, user_id: str, device_id: str, key_ids: List[str]
|
||||
) -> Dict[Tuple[str, str], str]:
|
||||
"""Retrieve a number of one-time keys for a user
|
||||
|
||||
Args:
|
||||
@@ -186,11 +185,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
retrieve
|
||||
|
||||
Returns:
|
||||
deferred resolving to Dict[(str, str), str]: map from (algorithm,
|
||||
key_id) to json string for key
|
||||
A map from (algorithm, key_id) to json string for key
|
||||
"""
|
||||
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="e2e_one_time_keys_json",
|
||||
column="key_id",
|
||||
iterable=key_ids,
|
||||
@@ -202,17 +200,21 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
log_kv({"message": "Fetched one time keys for user", "one_time_keys": result})
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
|
||||
async def add_e2e_one_time_keys(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
time_now: int,
|
||||
new_keys: Iterable[Tuple[str, str, str]],
|
||||
) -> None:
|
||||
"""Insert some new one time keys for a device. Errors if any of the
|
||||
keys already exist.
|
||||
|
||||
Args:
|
||||
user_id(str): id of user to get keys for
|
||||
device_id(str): id of device to get keys for
|
||||
time_now(long): insertion time to record (ms since epoch)
|
||||
new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
|
||||
(algorithm, key_id, key json)
|
||||
user_id: id of user to get keys for
|
||||
device_id: id of device to get keys for
|
||||
time_now: insertion time to record (ms since epoch)
|
||||
new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
|
||||
"""
|
||||
|
||||
def _add_e2e_one_time_keys(txn):
|
||||
@@ -242,7 +244,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
txn, self.count_e2e_one_time_keys, (user_id, device_id)
|
||||
)
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
|
||||
)
|
||||
|
||||
@@ -269,22 +271,23 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
"count_e2e_one_time_keys", _count_e2e_one_time_keys
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_cross_signing_key(self, user_id, key_type, from_user_id=None):
|
||||
async def get_e2e_cross_signing_key(
|
||||
self, user_id: str, key_type: str, from_user_id: Optional[str] = None
|
||||
) -> Optional[dict]:
|
||||
"""Returns a user's cross-signing key.
|
||||
|
||||
Args:
|
||||
user_id (str): the user whose key is being requested
|
||||
key_type (str): the type of key that is being requested: either 'master'
|
||||
user_id: the user whose key is being requested
|
||||
key_type: the type of key that is being requested: either 'master'
|
||||
for a master key, 'self_signing' for a self-signing key, or
|
||||
'user_signing' for a user-signing key
|
||||
from_user_id (str): if specified, signatures made by this user on
|
||||
from_user_id: if specified, signatures made by this user on
|
||||
the self-signing key will be included in the result
|
||||
|
||||
Returns:
|
||||
dict of the key data or None if not found
|
||||
"""
|
||||
res = yield self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
|
||||
res = await self.get_e2e_cross_signing_keys_bulk([user_id], from_user_id)
|
||||
user_keys = res.get(user_id)
|
||||
if not user_keys:
|
||||
return None
|
||||
@@ -450,28 +453,26 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
|
||||
return keys
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_cross_signing_keys_bulk(
|
||||
self, user_ids: List[str], from_user_id: str = None
|
||||
) -> defer.Deferred:
|
||||
async def get_e2e_cross_signing_keys_bulk(
|
||||
self, user_ids: List[str], from_user_id: Optional[str] = None
|
||||
) -> Dict[str, Dict[str, dict]]:
|
||||
"""Returns the cross-signing keys for a set of users.
|
||||
|
||||
Args:
|
||||
user_ids (list[str]): the users whose keys are being requested
|
||||
from_user_id (str): if specified, signatures made by this user on
|
||||
user_ids: the users whose keys are being requested
|
||||
from_user_id: if specified, signatures made by this user on
|
||||
the self-signing keys will be included in the result
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
|
||||
key data. If a user's cross-signing keys were not found, either
|
||||
their user ID will not be in the dict, or their user ID will map
|
||||
to None.
|
||||
A map of user ID to key type to key data. If a user's cross-signing
|
||||
keys were not found, either their user ID will not be in the dict,
|
||||
or their user ID will map to None.
|
||||
"""
|
||||
|
||||
result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
|
||||
result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
|
||||
|
||||
if from_user_id:
|
||||
result = yield self.db_pool.runInteraction(
|
||||
result = await self.db_pool.runInteraction(
|
||||
"get_e2e_cross_signing_signatures",
|
||||
self._get_e2e_cross_signing_signatures_txn,
|
||||
result,
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
import logging
|
||||
from typing import List
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
|
||||
from synapse.util.caches.descriptors import cached
|
||||
@@ -252,16 +250,12 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
||||
"reap_monthly_active_users", _reap_users, reserved_users
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def upsert_monthly_active_user(self, user_id):
|
||||
async def upsert_monthly_active_user(self, user_id: str) -> None:
|
||||
"""Updates or inserts the user into the monthly active user table, which
|
||||
is used to track the current MAU usage of the server
|
||||
|
||||
Args:
|
||||
user_id (str): user to add/update
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
user_id: user to add/update
|
||||
"""
|
||||
# Support user never to be included in MAU stats. Note I can't easily call this
|
||||
# from upsert_monthly_active_user_txn because then I need a _txn form of
|
||||
@@ -271,11 +265,11 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
||||
# _initialise_reserved_users reasoning that it would be very strange to
|
||||
# include a support user in this context.
|
||||
|
||||
is_support = yield self.is_support_user(user_id)
|
||||
is_support = await self.is_support_user(user_id)
|
||||
if is_support:
|
||||
return
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"upsert_monthly_active_user", self.upsert_monthly_active_user_txn, user_id
|
||||
)
|
||||
|
||||
@@ -322,8 +316,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
||||
|
||||
return is_insert
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def populate_monthly_active_users(self, user_id):
|
||||
async def populate_monthly_active_users(self, user_id):
|
||||
"""Checks on the state of monthly active user limits and optionally
|
||||
add the user to the monthly active tables
|
||||
|
||||
@@ -332,14 +325,14 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
||||
"""
|
||||
if self._limit_usage_by_mau or self._mau_stats_only:
|
||||
# Trial users and guests should not be included as part of MAU group
|
||||
is_guest = yield self.is_guest(user_id)
|
||||
is_guest = await self.is_guest(user_id)
|
||||
if is_guest:
|
||||
return
|
||||
is_trial = yield self.is_trial_user(user_id)
|
||||
is_trial = await self.is_trial_user(user_id)
|
||||
if is_trial:
|
||||
return
|
||||
|
||||
last_seen_timestamp = yield self.user_last_seen_monthly_active(user_id)
|
||||
last_seen_timestamp = await self.user_last_seen_monthly_active(user_id)
|
||||
now = self.hs.get_clock().time_msec()
|
||||
|
||||
# We want to reduce to the total number of db writes, and are happy
|
||||
@@ -352,10 +345,10 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore):
|
||||
# False, there is no point in checking get_monthly_active_count - it
|
||||
# adds no value and will break the logic if max_mau_value is exceeded.
|
||||
if not self._limit_usage_by_mau:
|
||||
yield self.upsert_monthly_active_user(user_id)
|
||||
await self.upsert_monthly_active_user(user_id)
|
||||
else:
|
||||
count = yield self.get_monthly_active_count()
|
||||
count = await self.get_monthly_active_count()
|
||||
if count < self._max_mau_value:
|
||||
yield self.upsert_monthly_active_user(user_id)
|
||||
await self.upsert_monthly_active_user(user_id)
|
||||
elif now - last_seen_timestamp > LAST_SEEN_GRANULARITY:
|
||||
yield self.upsert_monthly_active_user(user_id)
|
||||
await self.upsert_monthly_active_user(user_id)
|
||||
|
||||
@@ -16,8 +16,7 @@
|
||||
import logging
|
||||
import re
|
||||
from collections import namedtuple
|
||||
|
||||
from twisted.internet import defer
|
||||
from typing import List, Optional
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
@@ -114,8 +113,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME, self._background_reindex_gin_search
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_search(self, progress, batch_size):
|
||||
async def _background_reindex_search(self, progress, batch_size):
|
||||
# we work through the events table from highest stream id to lowest
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
@@ -206,19 +204,18 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
|
||||
return len(event_search_rows)
|
||||
|
||||
result = yield self.db_pool.runInteraction(
|
||||
result = await self.db_pool.runInteraction(
|
||||
self.EVENT_SEARCH_UPDATE_NAME, reindex_search_txn
|
||||
)
|
||||
|
||||
if not result:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.EVENT_SEARCH_UPDATE_NAME
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_gin_search(self, progress, batch_size):
|
||||
async def _background_reindex_gin_search(self, progress, batch_size):
|
||||
"""This handles old synapses which used GIST indexes, if any;
|
||||
converting them back to be GIN as per the actual schema.
|
||||
"""
|
||||
@@ -255,15 +252,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
conn.set_session(autocommit=False)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
yield self.db_pool.runWithConnection(create_index)
|
||||
await self.db_pool.runWithConnection(create_index)
|
||||
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.EVENT_SEARCH_USE_GIN_POSTGRES_NAME
|
||||
)
|
||||
return 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _background_reindex_search_order(self, progress, batch_size):
|
||||
async def _background_reindex_search_order(self, progress, batch_size):
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
@@ -288,12 +284,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
)
|
||||
conn.set_session(autocommit=False)
|
||||
|
||||
yield self.db_pool.runWithConnection(create_index)
|
||||
await self.db_pool.runWithConnection(create_index)
|
||||
|
||||
pg = dict(progress)
|
||||
pg["have_added_indexes"] = True
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
|
||||
self.db_pool.updates._background_update_progress_txn,
|
||||
self.EVENT_SEARCH_ORDER_UPDATE_NAME,
|
||||
@@ -331,12 +327,12 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
|
||||
return len(rows), True
|
||||
|
||||
num_rows, finished = yield self.db_pool.runInteraction(
|
||||
num_rows, finished = await self.db_pool.runInteraction(
|
||||
self.EVENT_SEARCH_ORDER_UPDATE_NAME, reindex_search_txn
|
||||
)
|
||||
|
||||
if not finished:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
self.EVENT_SEARCH_ORDER_UPDATE_NAME
|
||||
)
|
||||
|
||||
@@ -347,8 +343,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(SearchStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def search_msgs(self, room_ids, search_term, keys):
|
||||
async def search_msgs(self, room_ids, search_term, keys):
|
||||
"""Performs a full text search over events with given keys.
|
||||
|
||||
Args:
|
||||
@@ -425,7 +420,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
# entire table from the database.
|
||||
sql += " ORDER BY rank DESC LIMIT 500"
|
||||
|
||||
results = yield self.db_pool.execute(
|
||||
results = await self.db_pool.execute(
|
||||
"search_msgs", self.db_pool.cursor_to_dict, sql, *args
|
||||
)
|
||||
|
||||
@@ -433,7 +428,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
|
||||
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
|
||||
# search results (which is a data leak)
|
||||
events = yield self.get_events_as_list(
|
||||
events = await self.get_events_as_list(
|
||||
[r["event_id"] for r in results],
|
||||
redact_behaviour=EventRedactBehaviour.BLOCK,
|
||||
)
|
||||
@@ -442,11 +437,11 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
|
||||
highlights = None
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
highlights = yield self._find_highlights_in_postgres(search_query, events)
|
||||
highlights = await self._find_highlights_in_postgres(search_query, events)
|
||||
|
||||
count_sql += " GROUP BY room_id"
|
||||
|
||||
count_results = yield self.db_pool.execute(
|
||||
count_results = await self.db_pool.execute(
|
||||
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
|
||||
)
|
||||
|
||||
@@ -462,19 +457,25 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
"count": count,
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def search_rooms(self, room_ids, search_term, keys, limit, pagination_token=None):
|
||||
async def search_rooms(
|
||||
self,
|
||||
room_ids: List[str],
|
||||
search_term: str,
|
||||
keys: List[str],
|
||||
limit,
|
||||
pagination_token: Optional[str] = None,
|
||||
) -> List[dict]:
|
||||
"""Performs a full text search over events with given keys.
|
||||
|
||||
Args:
|
||||
room_id (list): The room_ids to search in
|
||||
search_term (str): Search term to search for
|
||||
keys (list): List of keys to search in, currently supports
|
||||
"content.body", "content.name", "content.topic"
|
||||
pagination_token (str): A pagination token previously returned
|
||||
room_ids: The room_ids to search in
|
||||
search_term: Search term to search for
|
||||
keys: List of keys to search in, currently supports "content.body",
|
||||
"content.name", "content.topic"
|
||||
pagination_token: A pagination token previously returned
|
||||
|
||||
Returns:
|
||||
list of dicts
|
||||
Each match as a dictionary.
|
||||
"""
|
||||
clauses = []
|
||||
|
||||
@@ -577,7 +578,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
|
||||
args.append(limit)
|
||||
|
||||
results = yield self.db_pool.execute(
|
||||
results = await self.db_pool.execute(
|
||||
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
|
||||
)
|
||||
|
||||
@@ -585,7 +586,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
|
||||
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
|
||||
# search results (which is a data leak)
|
||||
events = yield self.get_events_as_list(
|
||||
events = await self.get_events_as_list(
|
||||
[r["event_id"] for r in results],
|
||||
redact_behaviour=EventRedactBehaviour.BLOCK,
|
||||
)
|
||||
@@ -594,11 +595,11 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
|
||||
highlights = None
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
highlights = yield self._find_highlights_in_postgres(search_query, events)
|
||||
highlights = await self._find_highlights_in_postgres(search_query, events)
|
||||
|
||||
count_sql += " GROUP BY room_id"
|
||||
|
||||
count_results = yield self.db_pool.execute(
|
||||
count_results = await self.db_pool.execute(
|
||||
"search_rooms_count", self.db_pool.cursor_to_dict, count_sql, *count_args
|
||||
)
|
||||
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
@@ -40,9 +38,8 @@ class SignatureWorkerStore(SQLBaseStore):
|
||||
|
||||
return self.db_pool.runInteraction("get_event_reference_hashes", f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_event_hashes(self, event_ids):
|
||||
hashes = yield self.get_event_reference_hashes(event_ids)
|
||||
async def add_event_hashes(self, event_ids):
|
||||
hashes = await self.get_event_reference_hashes(event_ids)
|
||||
hashes = {
|
||||
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
|
||||
for e_id, h in hashes.items()
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
import logging
|
||||
import re
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.state import StateFilter
|
||||
@@ -59,8 +57,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
"populate_user_directory_cleanup", self._populate_user_directory_cleanup
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_user_directory_createtables(self, progress, batch_size):
|
||||
async def _populate_user_directory_createtables(self, progress, batch_size):
|
||||
|
||||
# Get all the rooms that we want to process.
|
||||
def _make_staging_area(txn):
|
||||
@@ -102,45 +99,43 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
|
||||
self.db_pool.simple_insert_many_txn(txn, TEMP_TABLE + "_users", users)
|
||||
|
||||
new_pos = yield self.get_max_stream_id_in_current_state_deltas()
|
||||
yield self.db_pool.runInteraction(
|
||||
new_pos = await self.get_max_stream_id_in_current_state_deltas()
|
||||
await self.db_pool.runInteraction(
|
||||
"populate_user_directory_temp_build", _make_staging_area
|
||||
)
|
||||
yield self.db_pool.simple_insert(
|
||||
await self.db_pool.simple_insert(
|
||||
TEMP_TABLE + "_position", {"position": new_pos}
|
||||
)
|
||||
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"populate_user_directory_createtables"
|
||||
)
|
||||
return 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_user_directory_cleanup(self, progress, batch_size):
|
||||
async def _populate_user_directory_cleanup(self, progress, batch_size):
|
||||
"""
|
||||
Update the user directory stream position, then clean up the old tables.
|
||||
"""
|
||||
position = yield self.db_pool.simple_select_one_onecol(
|
||||
position = await self.db_pool.simple_select_one_onecol(
|
||||
TEMP_TABLE + "_position", None, "position"
|
||||
)
|
||||
yield self.update_user_directory_stream_pos(position)
|
||||
await self.update_user_directory_stream_pos(position)
|
||||
|
||||
def _delete_staging_area(txn):
|
||||
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_rooms")
|
||||
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_users")
|
||||
txn.execute("DROP TABLE IF EXISTS " + TEMP_TABLE + "_position")
|
||||
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"populate_user_directory_cleanup", _delete_staging_area
|
||||
)
|
||||
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"populate_user_directory_cleanup"
|
||||
)
|
||||
return 1
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_user_directory_process_rooms(self, progress, batch_size):
|
||||
async def _populate_user_directory_process_rooms(self, progress, batch_size):
|
||||
"""
|
||||
Args:
|
||||
progress (dict)
|
||||
@@ -151,7 +146,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
|
||||
# If we don't have progress filed, delete everything.
|
||||
if not progress:
|
||||
yield self.delete_all_from_user_dir()
|
||||
await self.delete_all_from_user_dir()
|
||||
|
||||
def _get_next_batch(txn):
|
||||
# Only fetch 250 rooms, so we don't fetch too many at once, even
|
||||
@@ -176,13 +171,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
|
||||
return rooms_to_work_on
|
||||
|
||||
rooms_to_work_on = yield self.db_pool.runInteraction(
|
||||
rooms_to_work_on = await self.db_pool.runInteraction(
|
||||
"populate_user_directory_temp_read", _get_next_batch
|
||||
)
|
||||
|
||||
# No more rooms -- complete the transaction.
|
||||
if not rooms_to_work_on:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"populate_user_directory_process_rooms"
|
||||
)
|
||||
return 1
|
||||
@@ -195,21 +190,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
processed_event_count = 0
|
||||
|
||||
for room_id, event_count in rooms_to_work_on:
|
||||
is_in_room = yield self.is_host_joined(room_id, self.server_name)
|
||||
is_in_room = await self.is_host_joined(room_id, self.server_name)
|
||||
|
||||
if is_in_room:
|
||||
is_public = yield self.is_room_world_readable_or_publicly_joinable(
|
||||
is_public = await self.is_room_world_readable_or_publicly_joinable(
|
||||
room_id
|
||||
)
|
||||
|
||||
users_with_profile = yield defer.ensureDeferred(
|
||||
state.get_current_users_in_room(room_id)
|
||||
)
|
||||
users_with_profile = await state.get_current_users_in_room(room_id)
|
||||
user_ids = set(users_with_profile)
|
||||
|
||||
# Update each user in the user directory.
|
||||
for user_id, profile in users_with_profile.items():
|
||||
yield self.update_profile_in_user_dir(
|
||||
await self.update_profile_in_user_dir(
|
||||
user_id, profile.display_name, profile.avatar_url
|
||||
)
|
||||
|
||||
@@ -223,7 +216,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
to_insert.add(user_id)
|
||||
|
||||
if to_insert:
|
||||
yield self.add_users_in_public_rooms(room_id, to_insert)
|
||||
await self.add_users_in_public_rooms(room_id, to_insert)
|
||||
to_insert.clear()
|
||||
else:
|
||||
for user_id in user_ids:
|
||||
@@ -243,22 +236,22 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
# If it gets too big, stop and write to the database
|
||||
# to prevent storing too much in RAM.
|
||||
if len(to_insert) >= self.SHARE_PRIVATE_WORKING_SET:
|
||||
yield self.add_users_who_share_private_room(
|
||||
await self.add_users_who_share_private_room(
|
||||
room_id, to_insert
|
||||
)
|
||||
to_insert.clear()
|
||||
|
||||
if to_insert:
|
||||
yield self.add_users_who_share_private_room(room_id, to_insert)
|
||||
await self.add_users_who_share_private_room(room_id, to_insert)
|
||||
to_insert.clear()
|
||||
|
||||
# We've finished a room. Delete it from the table.
|
||||
yield self.db_pool.simple_delete_one(
|
||||
await self.db_pool.simple_delete_one(
|
||||
TEMP_TABLE + "_rooms", {"room_id": room_id}
|
||||
)
|
||||
# Update the remaining counter.
|
||||
progress["remaining"] -= 1
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"populate_user_directory",
|
||||
self.db_pool.updates._background_update_progress_txn,
|
||||
"populate_user_directory_process_rooms",
|
||||
@@ -273,13 +266,12 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
|
||||
return processed_event_count
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_user_directory_process_users(self, progress, batch_size):
|
||||
async def _populate_user_directory_process_users(self, progress, batch_size):
|
||||
"""
|
||||
If search_all_users is enabled, add all of the users to the user directory.
|
||||
"""
|
||||
if not self.hs.config.user_directory_search_all_users:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"populate_user_directory_process_users"
|
||||
)
|
||||
return 1
|
||||
@@ -305,13 +297,13 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
|
||||
return users_to_work_on
|
||||
|
||||
users_to_work_on = yield self.db_pool.runInteraction(
|
||||
users_to_work_on = await self.db_pool.runInteraction(
|
||||
"populate_user_directory_temp_read", _get_next_batch
|
||||
)
|
||||
|
||||
# No more users -- complete the transaction.
|
||||
if not users_to_work_on:
|
||||
yield self.db_pool.updates._end_background_update(
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"populate_user_directory_process_users"
|
||||
)
|
||||
return 1
|
||||
@@ -322,18 +314,18 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
)
|
||||
|
||||
for user_id in users_to_work_on:
|
||||
profile = yield self.get_profileinfo(get_localpart_from_id(user_id))
|
||||
yield self.update_profile_in_user_dir(
|
||||
profile = await self.get_profileinfo(get_localpart_from_id(user_id))
|
||||
await self.update_profile_in_user_dir(
|
||||
user_id, profile.display_name, profile.avatar_url
|
||||
)
|
||||
|
||||
# We've finished processing a user. Delete it from the table.
|
||||
yield self.db_pool.simple_delete_one(
|
||||
await self.db_pool.simple_delete_one(
|
||||
TEMP_TABLE + "_users", {"user_id": user_id}
|
||||
)
|
||||
# Update the remaining counter.
|
||||
progress["remaining"] -= 1
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"populate_user_directory",
|
||||
self.db_pool.updates._background_update_progress_txn,
|
||||
"populate_user_directory_process_users",
|
||||
@@ -342,8 +334,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
|
||||
return len(users_to_work_on)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_room_world_readable_or_publicly_joinable(self, room_id):
|
||||
async def is_room_world_readable_or_publicly_joinable(self, room_id):
|
||||
"""Check if the room is either world_readable or publically joinable
|
||||
"""
|
||||
|
||||
@@ -353,20 +344,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
)
|
||||
|
||||
current_state_ids = yield self.get_filtered_current_state_ids(
|
||||
current_state_ids = await self.get_filtered_current_state_ids(
|
||||
room_id, StateFilter.from_types(types_to_filter)
|
||||
)
|
||||
|
||||
join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
|
||||
if join_rules_id:
|
||||
join_rule_ev = yield self.get_event(join_rules_id, allow_none=True)
|
||||
join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
|
||||
if join_rule_ev:
|
||||
if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
|
||||
return True
|
||||
|
||||
hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
|
||||
if hist_vis_id:
|
||||
hist_vis_ev = yield self.get_event(hist_vis_id, allow_none=True)
|
||||
hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
|
||||
if hist_vis_ev:
|
||||
if hist_vis_ev.content.get("history_visibility") == "world_readable":
|
||||
return True
|
||||
@@ -590,19 +581,18 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
"remove_from_user_dir", _remove_from_user_dir_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_in_dir_due_to_room(self, room_id):
|
||||
async def get_users_in_dir_due_to_room(self, room_id):
|
||||
"""Get all user_ids that are in the room directory because they're
|
||||
in the given room_id
|
||||
"""
|
||||
user_ids_share_pub = yield self.db_pool.simple_select_onecol(
|
||||
user_ids_share_pub = await self.db_pool.simple_select_onecol(
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="user_id",
|
||||
desc="get_users_in_dir_due_to_room",
|
||||
)
|
||||
|
||||
user_ids_share_priv = yield self.db_pool.simple_select_onecol(
|
||||
user_ids_share_priv = await self.db_pool.simple_select_onecol(
|
||||
table="users_who_share_private_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="other_user_id",
|
||||
@@ -645,8 +635,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
"remove_user_who_share_room", _remove_user_who_share_room_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_dir_rooms_user_is_in(self, user_id):
|
||||
async def get_user_dir_rooms_user_is_in(self, user_id):
|
||||
"""
|
||||
Returns the rooms that a user is in.
|
||||
|
||||
@@ -656,14 +645,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
Returns:
|
||||
list: user_id
|
||||
"""
|
||||
rows = yield self.db_pool.simple_select_onecol(
|
||||
rows = await self.db_pool.simple_select_onecol(
|
||||
table="users_who_share_private_rooms",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="room_id",
|
||||
desc="get_rooms_user_is_in",
|
||||
)
|
||||
|
||||
pub_rows = yield self.db_pool.simple_select_onecol(
|
||||
pub_rows = await self.db_pool.simple_select_onecol(
|
||||
table="users_in_public_rooms",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="room_id",
|
||||
@@ -674,32 +663,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
users.update(rows)
|
||||
return list(users)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_in_common_for_users(self, user_id, other_user_id):
|
||||
"""Given two user_ids find out the list of rooms they share.
|
||||
"""
|
||||
sql = """
|
||||
SELECT room_id FROM (
|
||||
SELECT c.room_id FROM current_state_events AS c
|
||||
INNER JOIN room_memberships AS m USING (event_id)
|
||||
WHERE type = 'm.room.member'
|
||||
AND m.membership = 'join'
|
||||
AND state_key = ?
|
||||
) AS f1 INNER JOIN (
|
||||
SELECT c.room_id FROM current_state_events AS c
|
||||
INNER JOIN room_memberships AS m USING (event_id)
|
||||
WHERE type = 'm.room.member'
|
||||
AND m.membership = 'join'
|
||||
AND state_key = ?
|
||||
) f2 USING (room_id)
|
||||
"""
|
||||
|
||||
rows = yield self.db_pool.execute(
|
||||
"get_rooms_in_common_for_users", None, sql, user_id, other_user_id
|
||||
)
|
||||
|
||||
return [room_id for room_id, in rows]
|
||||
|
||||
def get_user_directory_stream_pos(self):
|
||||
return self.db_pool.simple_select_one_onecol(
|
||||
table="user_directory_stream_pos",
|
||||
@@ -708,8 +671,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
desc="get_user_directory_stream_pos",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def search_user_dir(self, user_id, search_term, limit):
|
||||
async def search_user_dir(self, user_id, search_term, limit):
|
||||
"""Searches for users in directory
|
||||
|
||||
Returns:
|
||||
@@ -806,7 +768,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
|
||||
# This should be unreachable.
|
||||
raise Exception("Unrecognized database engine")
|
||||
|
||||
results = yield self.db_pool.execute(
|
||||
results = await self.db_pool.execute(
|
||||
"search_user_dir", self.db_pool.cursor_to_dict, sql, *args
|
||||
)
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
|
||||
|
||||
self.mock_as_api.query_alias.return_value = make_awaitable(True)
|
||||
self.mock_store.get_app_services.return_value = services
|
||||
self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
|
||||
self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
|
||||
Mock(room_id=room_id, servers=servers)
|
||||
)
|
||||
|
||||
|
||||
@@ -34,8 +34,10 @@ class DirectoryStoreTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_room_to_alias(self):
|
||||
yield self.store.create_room_alias_association(
|
||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||
yield defer.ensureDeferred(
|
||||
self.store.create_room_alias_association(
|
||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
@@ -45,24 +47,36 @@ class DirectoryStoreTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_alias_to_room(self):
|
||||
yield self.store.create_room_alias_association(
|
||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||
yield defer.ensureDeferred(
|
||||
self.store.create_room_alias_association(
|
||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||
)
|
||||
)
|
||||
|
||||
self.assertObjectHasAttributes(
|
||||
{"room_id": self.room.to_string(), "servers": ["test"]},
|
||||
(yield self.store.get_association_from_room_alias(self.alias)),
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_association_from_room_alias(self.alias)
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_delete_alias(self):
|
||||
yield self.store.create_room_alias_association(
|
||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||
yield defer.ensureDeferred(
|
||||
self.store.create_room_alias_association(
|
||||
room_alias=self.alias, room_id=self.room.to_string(), servers=["test"]
|
||||
)
|
||||
)
|
||||
|
||||
room_id = yield self.store.delete_room_alias(self.alias)
|
||||
room_id = yield defer.ensureDeferred(self.store.delete_room_alias(self.alias))
|
||||
self.assertEqual(self.room.to_string(), room_id)
|
||||
|
||||
self.assertIsNone(
|
||||
(yield self.store.get_association_from_room_alias(self.alias))
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_association_from_room_alias(self.alias)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
@@ -34,7 +34,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
||||
|
||||
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||
|
||||
res = yield self.store.get_e2e_device_keys((("user", "device"),))
|
||||
res = yield defer.ensureDeferred(
|
||||
self.store.get_e2e_device_keys((("user", "device"),))
|
||||
)
|
||||
self.assertIn("user", res)
|
||||
self.assertIn("device", res["user"])
|
||||
dev = res["user"]["device"]
|
||||
@@ -63,7 +65,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
||||
yield self.store.set_e2e_device_keys("user", "device", now, json)
|
||||
yield self.store.store_device("user", "device", "display_name")
|
||||
|
||||
res = yield self.store.get_e2e_device_keys((("user", "device"),))
|
||||
res = yield defer.ensureDeferred(
|
||||
self.store.get_e2e_device_keys((("user", "device"),))
|
||||
)
|
||||
self.assertIn("user", res)
|
||||
self.assertIn("device", res["user"])
|
||||
dev = res["user"]["device"]
|
||||
@@ -85,8 +89,8 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
|
||||
yield self.store.set_e2e_device_keys("user2", "device1", now, {"key": "json21"})
|
||||
yield self.store.set_e2e_device_keys("user2", "device2", now, {"key": "json22"})
|
||||
|
||||
res = yield self.store.get_e2e_device_keys(
|
||||
(("user1", "device1"), ("user2", "device2"))
|
||||
res = yield defer.ensureDeferred(
|
||||
self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
|
||||
)
|
||||
self.assertIn("user1", res)
|
||||
self.assertIn("device1", res["user1"])
|
||||
|
||||
@@ -19,6 +19,7 @@ from twisted.internet import defer
|
||||
from synapse.api.constants import UserTypes
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
from tests.unittest import default_config, override_config
|
||||
|
||||
FORTY_DAYS = 40 * 24 * 60 * 60
|
||||
@@ -230,7 +231,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.get_success(d)
|
||||
|
||||
self.store.upsert_monthly_active_user = Mock()
|
||||
self.store.upsert_monthly_active_user = Mock(
|
||||
side_effect=lambda user_id: make_awaitable(None)
|
||||
)
|
||||
|
||||
d = self.store.populate_monthly_active_users(user_id)
|
||||
self.get_success(d)
|
||||
@@ -238,7 +241,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||
self.store.upsert_monthly_active_user.assert_not_called()
|
||||
|
||||
def test_populate_monthly_users_should_update(self):
|
||||
self.store.upsert_monthly_active_user = Mock()
|
||||
self.store.upsert_monthly_active_user = Mock(
|
||||
side_effect=lambda user_id: make_awaitable(None)
|
||||
)
|
||||
|
||||
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
|
||||
|
||||
@@ -251,7 +256,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||
self.store.upsert_monthly_active_user.assert_called_once()
|
||||
|
||||
def test_populate_monthly_users_should_not_update(self):
|
||||
self.store.upsert_monthly_active_user = Mock()
|
||||
self.store.upsert_monthly_active_user = Mock(
|
||||
side_effect=lambda user_id: make_awaitable(None)
|
||||
)
|
||||
|
||||
self.store.is_trial_user = Mock(return_value=defer.succeed(False))
|
||||
self.store.user_last_seen_monthly_active = Mock(
|
||||
@@ -333,7 +340,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
|
||||
def test_no_users_when_not_tracking(self):
|
||||
self.store.upsert_monthly_active_user = Mock()
|
||||
self.store.upsert_monthly_active_user = Mock(
|
||||
side_effect=lambda user_id: make_awaitable(None)
|
||||
)
|
||||
|
||||
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
|
||||
def test_search_user_dir(self):
|
||||
# normally when alice searches the directory she should just find
|
||||
# bob because bobby doesn't share a room with her.
|
||||
r = yield self.store.search_user_dir(ALICE, "bob", 10)
|
||||
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
|
||||
self.assertFalse(r["limited"])
|
||||
self.assertEqual(1, len(r["results"]))
|
||||
self.assertDictEqual(
|
||||
@@ -51,7 +51,7 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
|
||||
def test_search_user_dir_all_users(self):
|
||||
self.hs.config.user_directory_search_all_users = True
|
||||
try:
|
||||
r = yield self.store.search_user_dir(ALICE, "bob", 10)
|
||||
r = yield defer.ensureDeferred(self.store.search_user_dir(ALICE, "bob", 10))
|
||||
self.assertFalse(r["limited"])
|
||||
self.assertEqual(2, len(r["results"]))
|
||||
self.assertDictEqual(
|
||||
|
||||
Reference in New Issue
Block a user