1
0

Track device IDs for pushers

This commit is contained in:
Brendan Abolivier
2022-09-16 18:02:20 +01:00
parent 8ae42ab8fa
commit 2c05a30952
6 changed files with 139 additions and 5 deletions

View File

@@ -117,6 +117,7 @@ class PusherConfig:
last_success: Optional[int]
failing_since: Optional[int]
enabled: bool
device_id: Optional[str]
def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation."""
@@ -130,6 +131,7 @@ class PusherConfig:
"profile_tag": self.profile_tag,
"pushkey": self.pushkey,
"enabled": self.enabled,
"device_id": self.device_id,
}

View File

@@ -107,6 +107,7 @@ class PusherPool:
data: JsonDict,
profile_tag: str = "",
enabled: bool = True,
device_id: Optional[str] = None,
) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
@@ -149,18 +150,20 @@ class PusherPool:
last_success=None,
failing_since=None,
enabled=enabled,
device_id=device_id,
)
)
# Before we actually persist the pusher, we check if the user already has one
# for this app ID and pushkey. If so, we want to keep the access token in place,
# since this could be one device modifying (e.g. enabling/disabling) another
# device's pusher.
# this app ID and pushkey. If so, we want to keep the access token and device ID
# in place, since this could be one device modifying (e.g. enabling/disabling)
# another device's pusher.
existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
user_id, app_id, pushkey
)
if existing_config:
access_token = existing_config.access_token
device_id = existing_config.device_id
await self.store.add_pusher(
user_id=user_id,
@@ -176,6 +179,7 @@ class PusherPool:
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag,
enabled=enabled,
device_id=device_id,
)
pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)

View File

@@ -57,7 +57,9 @@ class PushersRestServlet(RestServlet):
for pusher in pusher_dicts:
if self._msc3881_enabled:
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
del pusher["enabled"]
del pusher["device_id"]
return 200, {"pushers": pusher_dicts}
@@ -134,6 +136,7 @@ class PushersSetRestServlet(RestServlet):
data=content["data"],
profile_tag=content.get("profile_tag", ""),
enabled=enabled,
device_id=requester.device_id,
)
except PusherConfigException as pce:
raise SynapseError(

View File

@@ -124,7 +124,7 @@ class PusherWorkerStore(SQLBaseStore):
id, user_name, access_token, profile_tag, kind, app_id,
app_display_name, device_display_name, pushkey, ts, lang, data,
last_stream_ordering, last_success, failing_since,
COALESCE(enabled, TRUE) AS enabled
COALESCE(enabled, TRUE) AS enabled, device_id
FROM pushers
"""
@@ -477,7 +477,72 @@ class PusherWorkerStore(SQLBaseStore):
return number_deleted
class PusherStore(PusherWorkerStore):
class PusherBackgroundUpdatesStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
"set_device_id_for_pushers", self._set_device_id_for_pushers
)
async def _set_device_id_for_pushers(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to populate the device_id column of the pushers table."""
last_pusher_id = progress.get("pusher_id", 0)
def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT p.id, at.device_id
FROM pushers AS p
INNER JOIN access_tokens AS at
ON p.access_token = at.id
WHERE
p.access_token IS NOT NULL
AND at.device_id IS NOT NULL
AND p.id > ?
ORDER BY p.id
LIMIT ?
""",
(last_pusher_id, batch_size),
)
rows = self.db_pool.cursor_to_dict(txn)
if len(rows) == 0:
return 0
self.db_pool.simple_update_many_txn(
txn=txn,
table="pushers",
key_names=("id",),
key_values=[(row["id"],) for row in rows],
value_names=("device_id",),
value_values=[(row["device_id"],) for row in rows],
)
self.db_pool.updates._background_update_progress_txn(
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]}
)
return len(rows)
nb_processed = await self.db_pool.runInteraction(
"set_device_id_for_pushers", set_device_id_for_pushers_txn
)
if nb_processed < batch_size:
await self.db_pool.updates._end_background_update("set_device_id_for_pushers")
return nb_processed
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@@ -496,6 +561,7 @@ class PusherStore(PusherWorkerStore):
last_stream_ordering: int,
profile_tag: str = "",
enabled: bool = True,
device_id: Optional[str] = None,
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
@@ -515,6 +581,7 @@ class PusherStore(PusherWorkerStore):
"profile_tag": profile_tag,
"id": stream_id,
"enabled": enabled,
"device_id": device_id,
},
desc="add_pusher",
lock=False,

View File

@@ -0,0 +1,16 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE pushers ADD COLUMN device_id TEXT;

View File

@@ -22,6 +22,7 @@ from synapse.logging.context import make_deferred_yieldable
from synapse.push import PusherConfig, PusherConfigException
from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict
from synapse.util import Clock
@@ -913,3 +914,44 @@ class HTTPPusherTests(HomeserverTestCase):
# it didn't change.
self.assertEqual(len(pushers), 1)
self.assertEqual(pushers[0].access_token, token_id)
@override_config({"experimental_features": {"msc3881_enabled": True}})
def test_device_id(self) -> None:
"""Tests that a pusher created with a given device ID shows that device ID in
GET /pushers requests.
"""
self.register_user("user", "pass")
access_token = self.login("user", "pass")
# We create the pusher with an HTTP request rather than with
# _make_user_with_pusher so that we can test the device ID is correctly set when
# creating a pusher via an API call.
self.make_request(
method="POST",
path="/pushers/set",
content={
"kind": "http",
"app_id": "m.http",
"app_display_name": "HTTP Push Notifications",
"device_display_name": "pushy push",
"pushkey": "a@example.com",
"lang": "en",
"data": {"url": "http://example.com/_matrix/push/v1/notify"},
},
access_token=access_token,
)
# Look up the user info for the access token so we can compare the device ID.
lookup_result: TokenLookupResult = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertEqual(
channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
lookup_result.device_id,
)