Fix tests on postgresql (#3740)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -11,89 +12,91 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import tempfile
|
||||
|
||||
from mock import Mock, NonCallableMock
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.defer import Deferred
|
||||
import attr
|
||||
|
||||
from synapse.replication.tcp.client import (
|
||||
ReplicationClientFactory,
|
||||
ReplicationClientHandler,
|
||||
)
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.util.logcontext import PreserveLoggingContext, make_deferred_yieldable
|
||||
|
||||
from tests import unittest
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
|
||||
class TestReplicationClientHandler(ReplicationClientHandler):
|
||||
"""Overrides on_rdata so that we can wait for it to happen"""
|
||||
class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
def __init__(self, store):
|
||||
super(TestReplicationClientHandler, self).__init__(store)
|
||||
self._rdata_awaiters = []
|
||||
|
||||
def await_replication(self):
|
||||
d = Deferred()
|
||||
self._rdata_awaiters.append(d)
|
||||
return make_deferred_yieldable(d)
|
||||
|
||||
def on_rdata(self, stream_name, token, rows):
|
||||
awaiters = self._rdata_awaiters
|
||||
self._rdata_awaiters = []
|
||||
super(TestReplicationClientHandler, self).on_rdata(stream_name, token, rows)
|
||||
with PreserveLoggingContext():
|
||||
for a in awaiters:
|
||||
a.callback(None)
|
||||
|
||||
|
||||
class BaseSlavedStoreTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def setUp(self):
|
||||
self.hs = yield setup_test_homeserver(
|
||||
self.addCleanup,
|
||||
hs = self.setup_test_homeserver(
|
||||
"blue",
|
||||
http_client=None,
|
||||
federation_client=Mock(),
|
||||
ratelimiter=NonCallableMock(spec_set=["send_message"]),
|
||||
)
|
||||
self.hs.get_ratelimiter().send_message.return_value = (True, 0)
|
||||
|
||||
hs.get_ratelimiter().send_message.return_value = (True, 0)
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
|
||||
self.master_store = self.hs.get_datastore()
|
||||
self.slaved_store = self.STORE_TYPE(self.hs.get_db_conn(), self.hs)
|
||||
self.event_id = 0
|
||||
|
||||
server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||
# XXX: mktemp is unsafe and should never be used. but we're just a test.
|
||||
path = tempfile.mktemp(prefix="base_slaved_store_test_case_socket")
|
||||
listener = reactor.listenUNIX(path, server_factory)
|
||||
self.addCleanup(listener.stopListening)
|
||||
self.streamer = server_factory.streamer
|
||||
|
||||
self.replication_handler = TestReplicationClientHandler(self.slaved_store)
|
||||
self.replication_handler = ReplicationClientHandler(self.slaved_store)
|
||||
client_factory = ReplicationClientFactory(
|
||||
self.hs, "client_name", self.replication_handler
|
||||
)
|
||||
client_connector = reactor.connectUNIX(path, client_factory)
|
||||
self.addCleanup(client_factory.stopTrying)
|
||||
self.addCleanup(client_connector.disconnect)
|
||||
|
||||
server = server_factory.buildProtocol(None)
|
||||
client = client_factory.buildProtocol(None)
|
||||
|
||||
@attr.s
|
||||
class FakeTransport(object):
|
||||
|
||||
other = attr.ib()
|
||||
disconnecting = False
|
||||
buffer = attr.ib(default=b'')
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
|
||||
self.producer = producer
|
||||
|
||||
def _produce():
|
||||
self.producer.resumeProducing()
|
||||
reactor.callLater(0.1, _produce)
|
||||
|
||||
reactor.callLater(0.0, _produce)
|
||||
|
||||
def write(self, byt):
|
||||
self.buffer = self.buffer + byt
|
||||
|
||||
if getattr(self.other, "transport") is not None:
|
||||
self.other.dataReceived(self.buffer)
|
||||
self.buffer = b""
|
||||
|
||||
def writeSequence(self, seq):
|
||||
for x in seq:
|
||||
self.write(x)
|
||||
|
||||
client.makeConnection(FakeTransport(server))
|
||||
server.makeConnection(FakeTransport(client))
|
||||
|
||||
def replicate(self):
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
wait for the replication to occur.
|
||||
"""
|
||||
# xxx: should we be more specific in what we wait for?
|
||||
d = self.replication_handler.await_replication()
|
||||
self.streamer.on_notifier_poke()
|
||||
return d
|
||||
self.pump(0.1)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check(self, method, args, expected_result=None):
|
||||
master_result = yield getattr(self.master_store, method)(*args)
|
||||
slaved_result = yield getattr(self.slaved_store, method)(*args)
|
||||
master_result = self.get_success(getattr(self.master_store, method)(*args))
|
||||
slaved_result = self.get_success(getattr(self.slaved_store, method)(*args))
|
||||
if expected_result is not None:
|
||||
self.assertEqual(master_result, expected_result)
|
||||
self.assertEqual(slaved_result, expected_result)
|
||||
|
||||
@@ -12,9 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
|
||||
|
||||
from ._base import BaseSlavedStoreTestCase
|
||||
@@ -27,16 +24,19 @@ class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
STORE_TYPE = SlavedAccountDataStore
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_account_data(self):
|
||||
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
self.get_success(
|
||||
self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
|
||||
)
|
||||
self.replicate()
|
||||
self.check(
|
||||
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 1}
|
||||
)
|
||||
|
||||
yield self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
self.get_success(
|
||||
self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
|
||||
)
|
||||
self.replicate()
|
||||
self.check(
|
||||
"get_global_account_data_by_type_for_user", [TYPE, USER_ID], {"a": 2}
|
||||
)
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.events import FrozenEvent, _EventInternalMetadata
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
@@ -55,70 +53,66 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
def tearDown(self):
|
||||
[unpatch() for unpatch in self.unpatches]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_latest_event_ids_in_room(self):
|
||||
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
yield self.replicate()
|
||||
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
|
||||
create = self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.replicate()
|
||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [create.event_id])
|
||||
|
||||
join = yield self.persist(
|
||||
join = self.persist(
|
||||
type="m.room.member",
|
||||
key=USER_ID,
|
||||
membership="join",
|
||||
prev_events=[(create.event_id, {})],
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
|
||||
self.replicate()
|
||||
self.check("get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_redactions(self):
|
||||
yield self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||
|
||||
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
|
||||
yield self.replicate()
|
||||
yield self.check("get_event", [msg.event_id], msg)
|
||||
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
|
||||
self.replicate()
|
||||
self.check("get_event", [msg.event_id], msg)
|
||||
|
||||
redaction = yield self.persist(type="m.room.redaction", redacts=msg.event_id)
|
||||
yield self.replicate()
|
||||
redaction = self.persist(type="m.room.redaction", redacts=msg.event_id)
|
||||
self.replicate()
|
||||
|
||||
msg_dict = msg.get_dict()
|
||||
msg_dict["content"] = {}
|
||||
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
|
||||
msg_dict["unsigned"]["redacted_because"] = redaction
|
||||
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
|
||||
yield self.check("get_event", [msg.event_id], redacted)
|
||||
self.check("get_event", [msg.event_id], redacted)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_backfilled_redactions(self):
|
||||
yield self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
yield self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.persist(type="m.room.member", key=USER_ID, membership="join")
|
||||
|
||||
msg = yield self.persist(type="m.room.message", msgtype="m.text", body="Hello")
|
||||
yield self.replicate()
|
||||
yield self.check("get_event", [msg.event_id], msg)
|
||||
msg = self.persist(type="m.room.message", msgtype="m.text", body="Hello")
|
||||
self.replicate()
|
||||
self.check("get_event", [msg.event_id], msg)
|
||||
|
||||
redaction = yield self.persist(
|
||||
redaction = self.persist(
|
||||
type="m.room.redaction", redacts=msg.event_id, backfill=True
|
||||
)
|
||||
yield self.replicate()
|
||||
self.replicate()
|
||||
|
||||
msg_dict = msg.get_dict()
|
||||
msg_dict["content"] = {}
|
||||
msg_dict["unsigned"]["redacted_by"] = redaction.event_id
|
||||
msg_dict["unsigned"]["redacted_because"] = redaction
|
||||
redacted = FrozenEvent(msg_dict, msg.internal_metadata.get_dict())
|
||||
yield self.check("get_event", [msg.event_id], redacted)
|
||||
self.check("get_event", [msg.event_id], redacted)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_invites(self):
|
||||
yield self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
yield self.check("get_invited_rooms_for_user", [USER_ID_2], [])
|
||||
event = yield self.persist(
|
||||
type="m.room.member", key=USER_ID_2, membership="invite"
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.check("get_invited_rooms_for_user", [USER_ID_2], [])
|
||||
event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite")
|
||||
|
||||
self.replicate()
|
||||
|
||||
self.check(
|
||||
"get_invited_rooms_for_user",
|
||||
[USER_ID_2],
|
||||
[
|
||||
@@ -132,37 +126,34 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_push_actions_for_user(self):
|
||||
yield self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
yield self.persist(type="m.room.join", key=USER_ID, membership="join")
|
||||
yield self.persist(
|
||||
self.persist(type="m.room.create", key="", creator=USER_ID)
|
||||
self.persist(type="m.room.join", key=USER_ID, membership="join")
|
||||
self.persist(
|
||||
type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
|
||||
)
|
||||
event1 = yield self.persist(
|
||||
type="m.room.message", msgtype="m.text", body="hello"
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
|
||||
self.replicate()
|
||||
self.check(
|
||||
"get_unread_event_push_actions_by_room_for_user",
|
||||
[ROOM_ID, USER_ID_2, event1.event_id],
|
||||
{"highlight_count": 0, "notify_count": 0},
|
||||
)
|
||||
|
||||
yield self.persist(
|
||||
self.persist(
|
||||
type="m.room.message",
|
||||
msgtype="m.text",
|
||||
body="world",
|
||||
push_actions=[(USER_ID_2, ["notify"])],
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
self.replicate()
|
||||
self.check(
|
||||
"get_unread_event_push_actions_by_room_for_user",
|
||||
[ROOM_ID, USER_ID_2, event1.event_id],
|
||||
{"highlight_count": 0, "notify_count": 1},
|
||||
)
|
||||
|
||||
yield self.persist(
|
||||
self.persist(
|
||||
type="m.room.message",
|
||||
msgtype="m.text",
|
||||
body="world",
|
||||
@@ -170,8 +161,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
(USER_ID_2, ["notify", {"set_tweak": "highlight", "value": True}])
|
||||
],
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
self.replicate()
|
||||
self.check(
|
||||
"get_unread_event_push_actions_by_room_for_user",
|
||||
[ROOM_ID, USER_ID_2, event1.event_id],
|
||||
{"highlight_count": 1, "notify_count": 2},
|
||||
@@ -179,7 +170,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
event_id = 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def persist(
|
||||
self,
|
||||
sender=USER_ID,
|
||||
@@ -206,8 +196,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
depth = self.event_id
|
||||
|
||||
if not prev_events:
|
||||
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
|
||||
room_id
|
||||
latest_event_ids = self.get_success(
|
||||
self.master_store.get_latest_event_ids_in_room(room_id)
|
||||
)
|
||||
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
|
||||
|
||||
@@ -240,19 +230,23 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
)
|
||||
else:
|
||||
state_handler = self.hs.get_state_handler()
|
||||
context = yield state_handler.compute_event_context(event)
|
||||
context = self.get_success(state_handler.compute_event_context(event))
|
||||
|
||||
yield self.master_store.add_push_actions_to_staging(
|
||||
self.master_store.add_push_actions_to_staging(
|
||||
event.event_id, {user_id: actions for user_id, actions in push_actions}
|
||||
)
|
||||
|
||||
ordering = None
|
||||
if backfill:
|
||||
yield self.master_store.persist_events([(event, context)], backfilled=True)
|
||||
self.get_success(
|
||||
self.master_store.persist_events([(event, context)], backfilled=True)
|
||||
)
|
||||
else:
|
||||
ordering, _ = yield self.master_store.persist_event(event, context)
|
||||
ordering, _ = self.get_success(
|
||||
self.master_store.persist_event(event, context)
|
||||
)
|
||||
|
||||
if ordering:
|
||||
event.internal_metadata.stream_ordering = ordering
|
||||
|
||||
defer.returnValue(event)
|
||||
return event
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
|
||||
from ._base import BaseSlavedStoreTestCase
|
||||
@@ -27,13 +25,10 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
|
||||
|
||||
STORE_TYPE = SlavedReceiptsStore
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_receipt(self):
|
||||
yield self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
|
||||
yield self.master_store.insert_receipt(
|
||||
ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}
|
||||
)
|
||||
yield self.replicate()
|
||||
yield self.check(
|
||||
"get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}
|
||||
self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
|
||||
self.get_success(
|
||||
self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
|
||||
)
|
||||
self.replicate()
|
||||
self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})
|
||||
|
||||
Reference in New Issue
Block a user