From f59419377d4ef34952ee7ab898668ca94cd5d6d4 Mon Sep 17 00:00:00 2001 From: Devon Hudson Date: Sun, 9 Nov 2025 09:35:11 -0700 Subject: [PATCH] Refactor for clarity --- synapse/handlers/sliding_sync/extensions.py | 79 +++++++++++++-------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index de50560583..c2dffa4bf6 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -1044,6 +1044,43 @@ class SlidingSyncExtensionHandler: threads_in_timeline.add(thread_id) return threads_in_timeline + def _merge_prev_batch_token( + self, + current_token: Optional[StreamToken], + new_token: Optional[StreamToken], + ) -> Optional[StreamToken]: + """Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination. + + Args: + current_token: The current prev_batch token (may be None) + new_token: The new prev_batch token to merge (may be None) + + Returns: + The merged token (maximum of the two, or None if both are None) + """ + if new_token is None: + return current_token + if current_token is None: + return new_token + # Take the maximum (latest) prev_batch token for backwards pagination + if new_token.room_key.stream > current_token.room_key.stream: + return new_token + return current_token + + def _merge_thread_updates( + self, + target: Dict[str, List[ThreadUpdateInfo]], + source: Dict[str, List[ThreadUpdateInfo]], + ) -> None: + """Merge thread updates from source into target. + + Args: + target: The target dict to merge into (modified in place) + source: The source dict to merge from + """ + for thread_id, updates in source.items(): + target.setdefault(thread_id, []).extend(updates) + async def get_threads_extension_response( self, sync_config: SlidingSyncConfig, @@ -1136,21 +1173,12 @@ class SlidingSyncExtensionHandler: remaining_limit -= num_updates # Merge results - for thread_id, updates in room_thread_updates.items(): - all_thread_updates.setdefault(thread_id, []).extend(updates) + self._merge_thread_updates(all_thread_updates, room_thread_updates) - # If any room has a prev_batch, we should set the global prev_batch. - # We use the maximum (latest) prev_batch token for backwards pagination. - if room_prev_batch is not None: - if prev_batch_token is None: - prev_batch_token = room_prev_batch - else: - # Take the maximum (latest) prev_batch token for backwards pagination - if ( - room_prev_batch.room_key.stream - > prev_batch_token.room_key.stream - ): - prev_batch_token = room_prev_batch + # Merge prev_batch tokens + prev_batch_token = self._merge_prev_batch_token( + prev_batch_token, room_prev_batch + ) # Query for rooms where the user is joined, invited, or knocking, using the # normal to_token as the upper bound. @@ -1167,20 +1195,12 @@ class SlidingSyncExtensionHandler: ) # Merge results - for thread_id, updates in other_thread_updates.items(): - all_thread_updates.setdefault(thread_id, []).extend(updates) + self._merge_thread_updates(all_thread_updates, other_thread_updates) # Merge prev_batch tokens - if other_prev_batch is not None: - if prev_batch_token is None: - prev_batch_token = other_prev_batch - else: - # Take the maximum (latest) prev_batch token for backwards pagination - if ( - other_prev_batch.room_key.stream - > prev_batch_token.room_key.stream - ): - prev_batch_token = other_prev_batch + prev_batch_token = self._merge_prev_batch_token( + prev_batch_token, other_prev_batch + ) # Early return: no thread updates found if len(all_thread_updates) == 0: @@ -1210,10 +1230,9 @@ class SlidingSyncExtensionHandler: if not filtered_updates: return None - # Sort updates for each thread by stream_ordering DESC to ensure updates[0] is the latest. - # This is critical because the prev_batch token generation below assumes DESC order. - for updates in filtered_updates.values(): - updates.sort(key=lambda u: u.stream_ordering, reverse=True) + # Note: Updates are already sorted by stream_ordering DESC from the database query, + # and filter_events_for_client preserves order, so updates[0] is guaranteed to be + # the latest event for each thread. # Optionally fetch thread root events and their bundled aggregations thread_root_event_map = {}