apply changes from review
This commit is contained in:
@@ -14,10 +14,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Iterable
|
||||
|
||||
from six import iteritems
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
|
||||
from twisted.enterprise.adbapi import Connection
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
@@ -333,8 +336,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
def _get_e2e_cross_signing_keys_bulk_txn(
|
||||
self, txn, users, key_type, from_user_id=None
|
||||
):
|
||||
self,
|
||||
txn: Connection,
|
||||
users: Iterable[str],
|
||||
key_type: str,
|
||||
from_user_id: str = None,
|
||||
) -> dict:
|
||||
"""Returns the cross-signing keys for a set of users.
|
||||
|
||||
Args:
|
||||
@@ -347,68 +354,82 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
the keys will be included in the result
|
||||
|
||||
Returns:
|
||||
dict mapping from user ID to key data. If a user's cross-signing key
|
||||
was not found, their user ID will not be in the dict.
|
||||
dict[str, dict]: mapping from user ID to key data. If a user's
|
||||
cross-signing key was not found, their user ID will not be in the
|
||||
dict.
|
||||
"""
|
||||
sql = (
|
||||
"SELECT k.user_id, k.keydata, k.stream_id "
|
||||
" FROM e2e_cross_signing_keys k"
|
||||
" JOIN (SELECT user_id, MAX(stream_id) AS stream_id"
|
||||
" FROM e2e_cross_signing_keys"
|
||||
" WHERE keytype = ?"
|
||||
" GROUP BY user_id) s"
|
||||
" USING (user_id, stream_id)"
|
||||
" WHERE k.user_id IN (%s) AND k.keytype = ?"
|
||||
) % (",".join("?" for u in users))
|
||||
query_params = [key_type]
|
||||
query_params.extend(users)
|
||||
query_params.append(key_type)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
# convert to a list if needed, so that we can slice it
|
||||
if not isinstance(users, list):
|
||||
users = list(users)
|
||||
|
||||
result = {}
|
||||
devices = {}
|
||||
for row in rows:
|
||||
user_id = row["user_id"]
|
||||
key = json.loads(row["keydata"])
|
||||
result[user_id] = key
|
||||
for k in key["keys"].values():
|
||||
devices[user_id] = k
|
||||
|
||||
if devices and from_user_id:
|
||||
# if we're asked to get signatures, and we have any devices to get
|
||||
# signatures for, fetch the signatures
|
||||
sql = (
|
||||
"SELECT target_user_id, key_id, signature "
|
||||
" FROM e2e_cross_signing_signatures "
|
||||
" WHERE user_id = ?"
|
||||
" AND (%s)"
|
||||
) % (
|
||||
" OR ".join(
|
||||
"(target_user_id = ? AND target_device_id = ?)" for d in devices
|
||||
)
|
||||
batch_size = 100
|
||||
chunks = [users[i : i + batch_size] for i in range(0, len(users), batch_size)]
|
||||
for user_chunk in chunks:
|
||||
sql = """
|
||||
SELECT k.user_id, k.keydata, k.stream_id
|
||||
FROM e2e_cross_signing_keys k
|
||||
INNER JOIN (SELECT user_id, MAX(stream_id) AS stream_id
|
||||
FROM e2e_cross_signing_keys
|
||||
WHERE keytype = ?
|
||||
GROUP BY user_id) s
|
||||
USING (user_id, stream_id)
|
||||
WHERE k.user_id IN (%s) AND k.keytype = ?
|
||||
""" % (
|
||||
",".join("?" for u in user_chunk)
|
||||
)
|
||||
query_params = [from_user_id]
|
||||
for item in devices.items():
|
||||
# item is a (user_id, device_id) tuple
|
||||
query_params.extend(item)
|
||||
query_params = [key_type]
|
||||
query_params.extend(user_chunk)
|
||||
query_params.append(key_type)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
# and add the signatures to the appropriate keys
|
||||
devices = {}
|
||||
for row in rows:
|
||||
key_id = row["key_id"]
|
||||
target_user_id = row["target_user_id"]
|
||||
target_user_key = result[target_user_id]
|
||||
signatures = target_user_key.setdefault("signatures", {})
|
||||
user_sigs = signatures.setdefault(from_user_id, {})
|
||||
user_sigs[key_id] = row["signature"]
|
||||
user_id = row["user_id"]
|
||||
key = json.loads(row["keydata"])
|
||||
result[user_id] = key
|
||||
for k in key["keys"].values():
|
||||
devices[user_id] = k
|
||||
|
||||
if devices and from_user_id:
|
||||
# if we're asked to get signatures, and we have any devices to get
|
||||
# signatures for, fetch the signatures
|
||||
sql = """
|
||||
SELECT target_user_id, key_id, signature
|
||||
FROM e2e_cross_signing_signatures
|
||||
WHERE user_id = ?
|
||||
AND (%s)
|
||||
""" % (
|
||||
" OR ".join(
|
||||
"(target_user_id = ? AND target_device_id = ?)" for d in devices
|
||||
)
|
||||
)
|
||||
query_params = [from_user_id]
|
||||
for item in devices.items():
|
||||
# item is a (user_id, device_id) tuple
|
||||
query_params.extend(item)
|
||||
|
||||
txn.execute(sql, query_params)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
|
||||
# and add the signatures to the appropriate keys
|
||||
for row in rows:
|
||||
key_id = row["key_id"]
|
||||
target_user_id = row["target_user_id"]
|
||||
target_user_key = result[target_user_id]
|
||||
signatures = target_user_key.setdefault("signatures", {})
|
||||
user_sigs = signatures.setdefault(from_user_id, {})
|
||||
user_sigs[key_id] = row["signature"]
|
||||
|
||||
return result
|
||||
|
||||
def get_e2e_cross_signing_keys_bulk(self, users, key_type, from_user_id=None):
|
||||
def get_e2e_cross_signing_keys_bulk(
|
||||
self, users: Iterable[str], key_type: str, from_user_id: str = None
|
||||
) -> defer.Deferred:
|
||||
"""Returns the cross-signing keys for a set of users.
|
||||
|
||||
Args:
|
||||
@@ -420,8 +441,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
the self-signing keys will be included in the result
|
||||
|
||||
Returns:
|
||||
Deferred[dict] map of user ID to key data. If a user's cross-signing key
|
||||
was not found, their user ID will not be in the dict.
|
||||
Deferred[dict[str, dict]]: map of user ID to key data. If a user's
|
||||
cross-signing key was not found, their user ID will not be in
|
||||
the dict.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_e2e_cross_signing_key",
|
||||
|
||||
Reference in New Issue
Block a user