Add account_keys table to store all key<->name mappings
This commit is contained in:
@@ -32,6 +32,7 @@ from synapse.storage.database import (
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.account_keys import AccountKeysStore
|
||||
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
|
||||
from synapse.storage.databases.main.stats import UserSortOrder
|
||||
from synapse.storage.databases.main.thread_subscriptions import (
|
||||
@@ -163,6 +164,7 @@ class DataStore(
|
||||
TaskSchedulerWorkerStore,
|
||||
SlidingSyncStore,
|
||||
DelayedEventsStore,
|
||||
AccountKeysStore,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
166
synapse/storage/databases/main/account_keys.py
Normal file
166
synapse/storage/databases/main/account_keys.py
Normal file
@@ -0,0 +1,166 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
|
||||
from typing import TYPE_CHECKING, Collection, Dict, List, Tuple, cast
|
||||
|
||||
from signedjson.key import (
|
||||
decode_signing_key_base64,
|
||||
generate_signing_key,
|
||||
get_verify_key,
|
||||
)
|
||||
from signedjson.types import SigningKey
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.types import get_domain_from_id, get_localpart_from_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
|
||||
class AccountKeysStore(SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
async def get_or_create_account_key_user_id_for_account_name_user_id(
|
||||
self, account_name_user_id: str
|
||||
) -> Tuple[str, SigningKey]:
|
||||
"""
|
||||
Get or create an account key for the given account name user ID.
|
||||
The user ID must belong to this server.
|
||||
|
||||
Args:
|
||||
account_name_user_id: An account name user ID e.g "@alice:example.com"
|
||||
Returns:
|
||||
A tuple of account key user ID e.g @l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ:example.com
|
||||
and the private key for the account.
|
||||
Raises:
|
||||
if the provided account name user ID is not owned by this homeserver, or if the user
|
||||
ID is invalid in some way.
|
||||
"""
|
||||
if not self.hs.is_mine_id(account_name_user_id):
|
||||
raise SynapseError(
|
||||
500,
|
||||
(
|
||||
"get_or_create_account_key_user_id_for_account_name_user_id: this server cannot"
|
||||
f" create an account key for other servers: {account_name_user_id}"
|
||||
),
|
||||
)
|
||||
|
||||
row = await self.db_pool.simple_select_one(
|
||||
table="account_keys",
|
||||
keyvalues={
|
||||
"account_name_user_id": account_name_user_id,
|
||||
},
|
||||
retcols=["account_key_user_id", "account_key"],
|
||||
allow_none=True,
|
||||
desc="get_or_create_account_key_user_id_for_account_name_user_id.get_key_txn",
|
||||
)
|
||||
if row is not None:
|
||||
return row[0], decode_account_key(row[1], get_localpart_from_id(row[0]))
|
||||
|
||||
# create a new account key for this account inside a txn to ensure we lock correctly.
|
||||
def create_key_txn(txn: LoggingTransaction) -> Tuple[str, str]:
|
||||
key = generate_account_key()
|
||||
account_key_user_id = (
|
||||
f"@{key.version}:{get_domain_from_id(account_name_user_id)}"
|
||||
)
|
||||
|
||||
# Race to insert the key. The first one to make it will be returned here as we don't clobber
|
||||
sql = (
|
||||
"INSERT INTO account_keys(account_name_user_id, account_key_user_id, account_key)"
|
||||
" VALUES(?, ?, ?)"
|
||||
" ON CONFLICT DO NOTHING"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
account_name_user_id,
|
||||
account_key_user_id,
|
||||
encode_base64(key.encode(), urlsafe=True),
|
||||
),
|
||||
)
|
||||
sql = "SELECT account_key_user_id, account_key FROM account_keys WHERE account_name_user_id = ?"
|
||||
txn.execute(sql, (account_name_user_id,))
|
||||
return cast(Tuple[str, str], txn.fetchone())
|
||||
|
||||
row = await self.db_pool.runInteraction(
|
||||
"get_or_create_account_key_user_id_for_account_name_user_id.create_key_txn",
|
||||
create_key_txn,
|
||||
)
|
||||
return row[0], decode_account_key(row[1], get_localpart_from_id(row[0]))
|
||||
|
||||
async def get_account_name_user_ids_for_account_key_user_ids(
|
||||
self,
|
||||
account_key_user_ids: Collection[str],
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Fetch the verified account name user IDs for the given account key user IDs. Unknown account key
|
||||
user IDs will be omitted from the dict.
|
||||
|
||||
Args:
|
||||
account_key_user_ids: A list of user IDs in account key format e.g
|
||||
["@l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ:example.com"]
|
||||
|
||||
Returns:
|
||||
A map of account key user IDs to account name user IDs e.g.
|
||||
{"@l8Hft5qXKn1vfHrg3p4+W8gELQVo8N13JkluMfmn2sQ:example.com":"@alice:example.com"}
|
||||
"""
|
||||
|
||||
clause, args = make_in_list_sql_clause(
|
||||
self.database_engine, "account_key_user_id", account_key_user_ids
|
||||
)
|
||||
|
||||
def f(txn: LoggingTransaction) -> List[Tuple[str, str]]:
|
||||
sql = f"SELECT account_key_user_id, account_name_user_id FROM account_keys WHERE {clause} AND account_name_user_id IS NOT NULL"
|
||||
txn.execute(sql, args)
|
||||
return cast(List[Tuple[str, str]], txn.fetchall())
|
||||
|
||||
rows = await self.db_pool.runInteraction(
|
||||
"get_account_name_user_ids_for_account_key_user_ids", f
|
||||
)
|
||||
return {row[0]: row[1] for row in rows}
|
||||
|
||||
|
||||
def generate_account_key() -> SigningKey:
|
||||
signing_key = generate_signing_key("1") # '1' will be replaced with the public key
|
||||
verify_key_str = encode_base64(get_verify_key(signing_key).encode(), urlsafe=True)
|
||||
signing_key.version = verify_key_str
|
||||
return signing_key
|
||||
|
||||
|
||||
def decode_account_key(signing_key: str, verify_key: str) -> SigningKey:
|
||||
return decode_signing_key_base64(
|
||||
"ed25519",
|
||||
verify_key,
|
||||
signing_key,
|
||||
)
|
||||
@@ -19,7 +19,7 @@
|
||||
#
|
||||
#
|
||||
|
||||
SCHEMA_VERSION = 92 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 93 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
||||
25
synapse/storage/schema/main/delta/93/01_account_keys.sql
Normal file
25
synapse/storage/schema/main/delta/93/01_account_keys.sql
Normal file
@@ -0,0 +1,25 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2025 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
-- Keeps a record of MSC4243 account key <--> account name mappings for all servers.
|
||||
-- This mapping is permanent.
|
||||
CREATE TABLE account_keys (
|
||||
account_key_user_id TEXT PRIMARY KEY NOT NULL,
|
||||
-- nullable if we cannot talk to the remote server.
|
||||
account_name_user_id TEXT,
|
||||
-- the private key as urlsafe base64, only for local accounts
|
||||
account_key TEXT,
|
||||
UNIQUE(account_key_user_id, account_name_user_id)
|
||||
);
|
||||
|
||||
CREATE INDEX account_keys_key_for_name ON account_keys (account_name_user_id) WHERE account_name_user_id IS NOT NULL;
|
||||
94
tests/storage/test_account_keys.py
Normal file
94
tests/storage/test_account_keys.py
Normal file
@@ -0,0 +1,94 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
#
|
||||
# Originally licensed under the Apache License, Version 2.0:
|
||||
# <http://www.apache.org/licenses/LICENSE-2.0>.
|
||||
#
|
||||
# [This file includes modifications made by New Vector Limited]
|
||||
#
|
||||
#
|
||||
|
||||
|
||||
from signedjson.key import get_verify_key
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import get_localpart_from_id
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class AccountKeysTestCase(unittest.HomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.store = self.hs.get_datastores().main
|
||||
self.user = "@user:test"
|
||||
|
||||
def test_get_or_create_account_key_user_id_for_account_name_user_id(self) -> None:
|
||||
key_user_id, key = self.get_success(
|
||||
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
|
||||
self.user
|
||||
)
|
||||
)
|
||||
# asserts the localpart is unpadded urlsafe base64
|
||||
self.assertRegex(key_user_id, r"^@[A-Za-z0-9\-_]{43}:test$")
|
||||
# asserts the key ID is the localpart
|
||||
self.assertEquals(key.version, get_localpart_from_id(key_user_id))
|
||||
# asserts the key ID is the public key
|
||||
self.assertEquals(
|
||||
key.version, encode_base64(get_verify_key(key).encode(), urlsafe=True)
|
||||
)
|
||||
# assert that repeated calls return the same key
|
||||
key_user_id2, key2 = self.get_success(
|
||||
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
|
||||
self.user
|
||||
)
|
||||
)
|
||||
self.assertEquals(key_user_id, key_user_id2)
|
||||
self.assertEquals(key.encode(), key2.encode())
|
||||
|
||||
def test_get_account_name_user_ids_for_account_key_user_ids(self) -> None:
|
||||
key_user_id, key = self.get_success(
|
||||
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
|
||||
self.user,
|
||||
)
|
||||
)
|
||||
result = self.get_success(
|
||||
self.store.get_account_name_user_ids_for_account_key_user_ids(
|
||||
[key_user_id]
|
||||
),
|
||||
)
|
||||
self.assertEquals(result[key_user_id], self.user)
|
||||
|
||||
def test_get_account_name_user_ids_for_account_key_user_ids_multiple(self) -> None:
|
||||
key_user_id_alice, _ = self.get_success(
|
||||
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
|
||||
"@alice:test",
|
||||
)
|
||||
)
|
||||
key_user_id_bob, _ = self.get_success(
|
||||
self.store.get_or_create_account_key_user_id_for_account_name_user_id(
|
||||
"@bob:test",
|
||||
)
|
||||
)
|
||||
key_user_id_unknown = "@6fey6W1wS3-vbvUmHZnTd6Gi3o-TIxvIcwtEQP4nrW0:test"
|
||||
result = self.get_success(
|
||||
self.store.get_account_name_user_ids_for_account_key_user_ids(
|
||||
[key_user_id_alice, key_user_id_bob, key_user_id_unknown]
|
||||
),
|
||||
)
|
||||
self.assertEquals(result[key_user_id_alice], "@alice:test")
|
||||
self.assertEquals(result[key_user_id_bob], "@bob:test")
|
||||
self.assertEquals(result.get(key_user_id_unknown, None), None)
|
||||
Reference in New Issue
Block a user