1
0

Add subscribed and automatic to get_updated_thread_subscriptions_for_user

This commit is contained in:
Olivier 'reivilibre
2025-07-18 11:25:43 +01:00
parent 748316c14a
commit 4dcd12b8d1
2 changed files with 21 additions and 9 deletions

View File

@@ -553,14 +553,14 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
limit: The maximum number of rows to return
Returns:
A list of (stream_id, room_id, thread_root_event_id) tuples.
A list of (stream_id, room_id, thread_root_event_id, subscribed, automatic) tuples.
"""
def get_updated_thread_subscriptions_for_user_txn(
txn: LoggingTransaction,
) -> List[Tuple[int, str, str]]:
) -> List[Tuple[int, str, str, bool, Optional[bool]]]:
sql = """
SELECT stream_id, room_id, event_id
SELECT stream_id, room_id, event_id, subscribed, automatic
FROM thread_subscriptions
WHERE user_id = ? AND ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
@@ -568,7 +568,17 @@ class ThreadSubscriptionsWorkerStore(CacheInvalidationWorkerStore):
"""
txn.execute(sql, (user_id, from_id, to_id, limit))
return [(row[0], row[1], row[2]) for row in txn]
return [
(
stream_id,
room_id,
event_id,
# SQLite integer to boolean conversions
bool(subscribed),
bool(automatic) if subscribed else None,
)
for (stream_id, room_id, event_id, subscribed, automatic) in txn
]
return await self.db_pool.runInteraction(
"get_updated_thread_subscriptions_for_user",

View File

@@ -196,12 +196,12 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
limit=50,
)
)
min_id = min(id for (id, _, _) in subscriptions)
min_id = min(id for (id, _, _, _, _) in subscriptions)
self.assertEqual(
subscriptions,
[
(min_id, self.room_id, self.thread_root_id),
(min_id + 1, self.room_id, self.other_thread_root_id),
(min_id, self.room_id, self.thread_root_id, True, True),
(min_id + 1, self.room_id, self.other_thread_root_id, True, False),
],
)
@@ -284,7 +284,9 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
self.user_id, from_id=0, to_id=stream_id2, limit=10
)
)
self.assertEqual(updates, [(stream_id1, self.room_id, self.thread_root_id)])
self.assertEqual(
updates, [(stream_id1, self.room_id, self.thread_root_id, True, True)]
)
# Get updates for other user
updates = self.get_success(
@@ -293,7 +295,7 @@ class ThreadSubscriptionsTestCase(unittest.HomeserverTestCase):
)
)
self.assertEqual(
updates, [(stream_id2, self.room_id, self.other_thread_root_id)]
updates, [(stream_id2, self.room_id, self.other_thread_root_id, True, True)]
)
def test_should_skip_autosubscription_after_unsubscription(self) -> None: