1
0

apply changes from review

This commit is contained in:
Hubert Chathi
2019-12-06 12:31:37 -05:00
parent 74288a793a
commit 44d4e4d7cb

View File

@@ -14,10 +14,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable
from six import iteritems
from canonicaljson import encode_canonical_json, json
from twisted.enterprise.adbapi import Connection
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -333,8 +336,12 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
def _get_e2e_cross_signing_keys_bulk_txn(
self, txn, users, key_type, from_user_id=None
):
self,
txn: Connection,
users: Iterable[str],
key_type: str,
from_user_id: str = None,
) -> dict:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -347,68 +354,82 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the keys will be included in the result
Returns:
dict mapping from user ID to key data. If a user's cross-signing key
was not found, their user ID will not be in the dict.
dict[str, dict]: mapping from user ID to key data. If a user's
cross-signing key was not found, their user ID will not be in the
dict.
"""
sql = (
"SELECT k.user_id, k.keydata, k.stream_id "
" FROM e2e_cross_signing_keys k"
" JOIN (SELECT user_id, MAX(stream_id) AS stream_id"
" FROM e2e_cross_signing_keys"
" WHERE keytype = ?"
" GROUP BY user_id) s"
" USING (user_id, stream_id)"
" WHERE k.user_id IN (%s) AND k.keytype = ?"
) % (",".join("?" for u in users))
query_params = [key_type]
query_params.extend(users)
query_params.append(key_type)
txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn)
# convert to a list if needed, so that we can slice it
if not isinstance(users, list):
users = list(users)
result = {}
devices = {}
for row in rows:
user_id = row["user_id"]
key = json.loads(row["keydata"])
result[user_id] = key
for k in key["keys"].values():
devices[user_id] = k
if devices and from_user_id:
# if we're asked to get signatures, and we have any devices to get
# signatures for, fetch the signatures
sql = (
"SELECT target_user_id, key_id, signature "
" FROM e2e_cross_signing_signatures "
" WHERE user_id = ?"
" AND (%s)"
) % (
" OR ".join(
"(target_user_id = ? AND target_device_id = ?)" for d in devices
)
batch_size = 100
chunks = [users[i : i + batch_size] for i in range(0, len(users), batch_size)]
for user_chunk in chunks:
sql = """
SELECT k.user_id, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
WHERE keytype = ?
GROUP BY user_id) s
USING (user_id, stream_id)
WHERE k.user_id IN (%s) AND k.keytype = ?
""" % (
",".join("?" for u in user_chunk)
)
query_params = [from_user_id]
for item in devices.items():
# item is a (user_id, device_id) tuple
query_params.extend(item)
query_params = [key_type]
query_params.extend(user_chunk)
query_params.append(key_type)
txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
devices = {}
for row in rows:
key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_user_key = result[target_user_id]
signatures = target_user_key.setdefault("signatures", {})
user_sigs = signatures.setdefault(from_user_id, {})
user_sigs[key_id] = row["signature"]
user_id = row["user_id"]
key = json.loads(row["keydata"])
result[user_id] = key
for k in key["keys"].values():
devices[user_id] = k
if devices and from_user_id:
# if we're asked to get signatures, and we have any devices to get
# signatures for, fetch the signatures
sql = """
SELECT target_user_id, key_id, signature
FROM e2e_cross_signing_signatures
WHERE user_id = ?
AND (%s)
""" % (
" OR ".join(
"(target_user_id = ? AND target_device_id = ?)" for d in devices
)
)
query_params = [from_user_id]
for item in devices.items():
# item is a (user_id, device_id) tuple
query_params.extend(item)
txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
for row in rows:
key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_user_key = result[target_user_id]
signatures = target_user_key.setdefault("signatures", {})
user_sigs = signatures.setdefault(from_user_id, {})
user_sigs[key_id] = row["signature"]
return result
def get_e2e_cross_signing_keys_bulk(self, users, key_type, from_user_id=None):
def get_e2e_cross_signing_keys_bulk(
self, users: Iterable[str], key_type: str, from_user_id: str = None
) -> defer.Deferred:
"""Returns the cross-signing keys for a set of users.
Args:
@@ -420,8 +441,9 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
the self-signing keys will be included in the result
Returns:
Deferred[dict] map of user ID to key data. If a user's cross-signing key
was not found, their user ID will not be in the dict.
Deferred[dict[str, dict]]: map of user ID to key data. If a user's
cross-signing key was not found, their user ID will not be in
the dict.
"""
return self.runInteraction(
"get_e2e_cross_signing_key",