Use ed25519:1 as the key ID

This commit is contained in:
Kegan Dougal
2025-09-10 15:35:37 +01:00
parent c0ffe61adb
commit b338f886d6
2 changed files with 14 additions and 17 deletions

View File

@@ -86,13 +86,13 @@ class AccountKeysStore(SQLBaseStore):
desc="get_or_create_account_key_user_id_for_account_name_user_id.get_key_txn", desc="get_or_create_account_key_user_id_for_account_name_user_id.get_key_txn",
) )
if row is not None: if row is not None:
return row[0], decode_account_key(row[1], get_localpart_from_id(row[0])) return row[0], decode_account_key(row[1])
# create a new account key for this account inside a txn to ensure we lock correctly. # create a new account key for this account inside a txn to ensure we lock correctly.
def create_key_txn(txn: LoggingTransaction) -> Tuple[str, str]: def create_key_txn(txn: LoggingTransaction) -> Tuple[str, str]:
key = generate_account_key() key, public_key_str = generate_account_key()
account_key_user_id = ( account_key_user_id = (
f"@{key.version}:{get_domain_from_id(account_name_user_id)}" f"@{public_key_str}:{get_domain_from_id(account_name_user_id)}"
) )
# Race to insert the key. The first one to make it will be returned here as we don't clobber # Race to insert the key. The first one to make it will be returned here as we don't clobber
@@ -117,7 +117,7 @@ class AccountKeysStore(SQLBaseStore):
"get_or_create_account_key_user_id_for_account_name_user_id.create_key_txn", "get_or_create_account_key_user_id_for_account_name_user_id.create_key_txn",
create_key_txn, create_key_txn,
) )
return row[0], decode_account_key(row[1], get_localpart_from_id(row[0])) return row[0], decode_account_key(row[1])
async def get_account_name_user_ids_for_account_key_user_ids( async def get_account_name_user_ids_for_account_key_user_ids(
self, self,
@@ -151,16 +151,15 @@ class AccountKeysStore(SQLBaseStore):
return {row[0]: row[1] for row in rows} return {row[0]: row[1] for row in rows}
def generate_account_key() -> SigningKey: def generate_account_key() -> Tuple[SigningKey, str]:
signing_key = generate_signing_key("1") # '1' will be replaced with the public key signing_key = generate_signing_key("1")
verify_key_str = encode_base64(get_verify_key(signing_key).encode(), urlsafe=True) verify_key_str = encode_base64(get_verify_key(signing_key).encode(), urlsafe=True)
signing_key.version = verify_key_str return signing_key, verify_key_str
return signing_key
def decode_account_key(signing_key: str, verify_key: str) -> SigningKey: def decode_account_key(signing_key: str) -> SigningKey:
return decode_signing_key_base64( return decode_signing_key_base64(
"ed25519", "ed25519",
verify_key, "1",
signing_key, signing_key,
) )

View File

@@ -44,12 +44,10 @@ class AccountKeysTestCase(unittest.HomeserverTestCase):
) )
# asserts the localpart is unpadded urlsafe base64 # asserts the localpart is unpadded urlsafe base64
self.assertRegex(key_user_id, r"^@[A-Za-z0-9\-_]{43}:test$") self.assertRegex(key_user_id, r"^@[A-Za-z0-9\-_]{43}:test$")
# asserts the key ID is the localpart # asserts the public key is the localpart
self.assertEquals(key.version, get_localpart_from_id(key_user_id)) self.assertEquals(encode_base64(get_verify_key(key).encode(), urlsafe=True), get_localpart_from_id(key_user_id))
# asserts the key ID is the public key # asserts the key ID is 1
self.assertEquals( self.assertEquals(key.version, "1")
key.version, encode_base64(get_verify_key(key).encode(), urlsafe=True)
)
# assert that repeated calls return the same key # assert that repeated calls return the same key
key_user_id2, key2 = self.get_success( key_user_id2, key2 = self.get_success(
self.store.get_or_create_account_key_user_id_for_account_name_user_id( self.store.get_or_create_account_key_user_id_for_account_name_user_id(
@@ -60,7 +58,7 @@ class AccountKeysTestCase(unittest.HomeserverTestCase):
self.assertEquals(key.encode(), key2.encode()) self.assertEquals(key.encode(), key2.encode())
def test_get_account_name_user_ids_for_account_key_user_ids(self) -> None: def test_get_account_name_user_ids_for_account_key_user_ids(self) -> None:
key_user_id, key = self.get_success( key_user_id, _ = self.get_success(
self.store.get_or_create_account_key_user_id_for_account_name_user_id( self.store.get_or_create_account_key_user_id_for_account_name_user_id(
self.user, self.user,
) )