1
0
This commit is contained in:
Brendan Abolivier
2021-12-06 17:09:10 +00:00
parent 996bed00f8
commit 94191b9151
6 changed files with 27 additions and 56 deletions

View File

@@ -314,13 +314,12 @@ class ProfileHandler:
authenticated_entity=requester.authenticated_entity,
)
new_batchnum = None
if len(self.replicate_user_profiles_to) > 0:
cur_batchnum = (
await self.store.get_latest_profile_replication_batch_number()
)
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
else:
new_batchnum = None
await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set, new_batchnum
@@ -366,13 +365,12 @@ class ProfileHandler:
False and active is False, user will have their profile
erased
"""
new_batchnum = None
if len(self.replicate_user_profiles_to) > 0:
cur_batchnum = (
await self.store.get_latest_profile_replication_batch_number()
)
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
else:
new_batchnum = None
await self.store.set_profiles_active(users, active, hide, new_batchnum)
@@ -487,13 +485,12 @@ class ProfileHandler:
target_user, authenticated_entity=requester.authenticated_entity
)
new_batchnum = None
if len(self.replicate_user_profiles_to) > 0:
cur_batchnum = (
await self.store.get_latest_profile_replication_batch_number()
)
new_batchnum = 0 if cur_batchnum is None else cur_batchnum + 1
else:
new_batchnum = None
await self.store.set_profile_avatar_url(
target_user.localpart, avatar_url_to_set, new_batchnum

View File

@@ -45,7 +45,7 @@ class AccountDataServlet(RestServlet):
self.store = hs.get_datastore()
self.handler = hs.get_account_data_handler()
self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker_app is not None
self._is_worker = hs.config.worker.worker_app is not None
self._profile_handler = hs.get_profile_handler()
async def on_PUT(

View File

@@ -652,7 +652,9 @@ class RegisterRestServlet(RestServlet):
# something else went wrong.
break
if self.hs.config.register_just_use_email_for_display_name:
if (
self.hs.config.registration.register_just_use_email_for_display_name
):
desired_display_name = address
else:
# Custom mapping between email address and display name

View File

@@ -72,7 +72,7 @@ class UserDirectorySearchRestServlet(RestServlet):
if self.hs.config.userdirectory.user_directory_defer_to_id_server:
signed_body = sign_json(
body, self.hs.hostname, self.hs.config.signing_key[0]
body, self.hs.hostname, self.hs.config.key.signing_key[0]
)
url = "%s/_matrix/identity/api/v1/user_directory/search" % (
self.hs.config.userdirectory.user_directory_defer_to_id_server,

View File

@@ -62,8 +62,8 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
async def get_latest_profile_replication_batch_number(self):
def f(txn):
async def get_latest_profile_replication_batch_number(self) -> Optional[int]:
def f(txn: LoggingTransaction) -> Optional[int]:
txn.execute("SELECT MAX(batch) as maxbatch FROM profiles")
rows = self.db_pool.cursor_to_dict(txn)
return rows[0]["maxbatch"]
@@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
"get_latest_profile_replication_batch_number", f
)
async def get_profile_batch(self, batchnum):
async def get_profile_batch(self, batchnum: int) -> List[Dict[str, Any]]:
return await self.db_pool.simple_select_list(
table="profiles",
keyvalues={"batch": batchnum},
@@ -80,8 +80,8 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_batch",
)
async def assign_profile_batch(self):
def f(txn):
async def assign_profile_batch(self) -> int:
def f(txn: LoggingTransaction) -> int:
sql = (
"UPDATE profiles SET batch = "
"(SELECT COALESCE(MAX(batch), -1) + 1 FROM profiles) "
@@ -94,8 +94,8 @@ class ProfileWorkerStore(SQLBaseStore):
return await self.db_pool.runInteraction("assign_profile_batch", f)
async def get_replication_hosts(self):
def f(txn):
async def get_replication_hosts(self) -> Dict[str, int]:
def f(txn: LoggingTransaction) -> Dict[str, int]:
txn.execute(
"SELECT host, last_synced_batch FROM profile_replication_status"
)
@@ -106,7 +106,7 @@ class ProfileWorkerStore(SQLBaseStore):
async def update_replication_batch_for_host(
self, host: str, last_synced_batch: int
):
) -> bool:
return await self.db_pool.simple_upsert(
table="profile_replication_status",
keyvalues={"host": host},
@@ -131,7 +131,10 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_displayname(
self, user_localpart: str, new_displayname: Optional[str], batchnum: int
self,
user_localpart: str,
new_displayname: Optional[str],
batchnum: Optional[int],
) -> None:
# Invalidate the read cache for this user
self.get_profile_displayname.invalidate((user_localpart,))
@@ -145,7 +148,10 @@ class ProfileWorkerStore(SQLBaseStore):
)
async def set_profile_avatar_url(
self, user_localpart: str, new_avatar_url: Optional[str], batchnum: int
self,
user_localpart: str,
new_avatar_url: Optional[str],
batchnum: Optional[int],
) -> None:
# Invalidate the read cache for this user
self.get_profile_avatar_url.invalidate((user_localpart,))
@@ -163,7 +169,7 @@ class ProfileWorkerStore(SQLBaseStore):
users: List[UserID],
active: bool,
hide: bool,
batchnum: int,
batchnum: Optional[int],
) -> None:
"""Given a set of users, set active and hidden flags on them.
@@ -179,13 +185,13 @@ class ProfileWorkerStore(SQLBaseStore):
user_localparts = [(user.localpart,) for user in users]
# Generate list of value tuples for each user
value_names = ("active", "batch")
value_names = ["active", "batch"]
values = [(int(active), batchnum) for _ in user_localparts] # type: List[Tuple]
if not active and not hide:
# we are deactivating for real (not in hide mode)
# so clear the profile information
value_names += ("avatar_url", "displayname")
value_names += ["avatar_url", "displayname"]
values = [v + (None, None) for v in values]
return await self.db_pool.runInteraction(
@@ -293,17 +299,6 @@ class ProfileWorkerStore(SQLBaseStore):
class ProfileStore(ProfileWorkerStore):
def __init__(self, database, db_conn, hs):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
"profile_replication_status_host_index",
index_name="profile_replication_status_idx",
table="profile_replication_status",
columns=["host"],
unique=True,
)
async def add_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> None:

View File

@@ -531,29 +531,6 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(state[("org.matrix.test", "")].state_key, "")
self.assertEqual(state[("org.matrix.test", "")].content, {})
def test_get_room_state(self):
"""Tests that a module can retrieve the state of a room through the module API."""
user_id = self.register_user("peter", "hackme")
tok = self.login("peter", "hackme")
# Create a room and send some custom state in it.
room_id = self.helper.create_room_as(tok=tok)
self.helper.send_state(room_id, "org.matrix.test", {}, tok=tok)
# Check that the module API can successfully fetch state for the room.
state = self.get_success(
defer.ensureDeferred(self.module_api.get_room_state(room_id))
)
# Check that a few standard events are in the returned state.
self.assertIn((EventTypes.Create, ""), state)
self.assertIn((EventTypes.Member, user_id), state)
# Check that our custom state event is in the returned state.
self.assertEqual(state[("org.matrix.test", "")].sender, user_id)
self.assertEqual(state[("org.matrix.test", "")].state_key, "")
self.assertEqual(state[("org.matrix.test", "")].content, {})
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""