diff --git a/synapse/storage/databases/main/thread_subscriptions.py b/synapse/storage/databases/main/thread_subscriptions.py index 7c48e5761c..52b8a587a8 100644 --- a/synapse/storage/databases/main/thread_subscriptions.py +++ b/synapse/storage/databases/main/thread_subscriptions.py @@ -553,14 +553,14 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): limit: The maximum number of rows to return Returns: - A list of (stream_id, room_id, thread_root_event_id) tuples. + A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples. """ def get_updated_thread_subscriptions_for_user_txn( txn: LoggingTransaction, - ) -> List[Tuple[int, str, str]]: + ) -> List[Tuple[int, str, str, bool, Optional[bool]]]: sql = """ - SELECT stream_id, room_id, event_id + SELECT stream_id, room_id, event_id, subscribed, automatic FROM thread_subscriptions WHERE user_id = ? AND ? < stream_id AND stream_id <= ? ORDER BY stream_id ASC @@ -568,7 +568,17 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore): """ txn.execute(sql, (user_id, from_id, to_id, limit)) - return [(row[0], row[1], row[2]) for row in txn] + return [ + ( + stream_id, + room_id, + event_id, + # SQLite integer to boolean conversions + bool(subscribed), + bool(automatic) if subscribed else None, + ) + for (stream_id, room_id, event_id, subscribed, automatic) in txn + ] return await self.db_pool.runInteraction( "get_updated_thread_subscriptions_for_user", diff --git a/tests/storage/test_thread_subscriptions.py b/tests/storage/test_thread_subscriptions.py index 2a5c440cf4..fd05754509 100644 --- a/tests/storage/test_thread_subscriptions.py +++ b/tests/storage/test_thread_subscriptions.py @@ -196,12 +196,12 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): limit=50, ) ) - min_id = min(id for (id, _, _) in subscriptions) + min_id = min(id for (id, _, _, _, _) in subscriptions) self.assertEqual( subscriptions, [ - (min_id, self.room_id, self.thread_root_id), - (min_id + 1, self.room_id, self.other_thread_root_id), + (min_id, self.room_id, self.thread_root_id, True, True), + (min_id + 1, self.room_id, self.other_thread_root_id, True, False), ], ) @@ -284,7 +284,9 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): self.user_id, from_id=0, to_id=stream_id2, limit=10 ) ) - self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)]) + self.assertEqual( + updates, [(stream_id1, self.room_id, self.thread_root_id, True, True)] + ) # Get updates for other user updates = self.get_success( @@ -293,7 +295,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase): ) ) self.assertEqual( - updates, [(stream_id2, self.room_id, self.other_thread_root_id)] + updates, [(stream_id2, self.room_id, self.other_thread_root_id, True, True)] ) def test_should_skip_autosubscription_after_unsubscription(self) -> None: