1
0

Compare commits

...

7 Commits

Author SHA1 Message Date
Erik Johnston
f5817281f8 Fixup2 2022-12-15 14:05:59 +00:00
Erik Johnston
87406aa5d3 Fixup 2022-12-15 13:20:20 +00:00
Erik Johnston
6842974391 Fixup 2022-12-15 13:15:47 +00:00
Erik Johnston
c93ef61fa3 WIP Rust HTTP for federation 2022-12-14 11:02:16 +00:00
David Robertson
e2a1adbf5d Allow selecting "prejoin" events by state keys (#14642)
* Declare new config

* Parse new config

* Read new config

* Don't use trial/our TestCase where it's not needed

Before:

```
$ time trial tests/events/test_utils.py > /dev/null

real	0m2.277s
user	0m2.186s
sys	0m0.083s
```

After:
```
$ time trial tests/events/test_utils.py > /dev/null

real	0m0.566s
user	0m0.508s
sys	0m0.056s
```

* Helper to upsert to event fields

without exceeding size limits.

* Use helper when adding invite/knock state

Now that we allow admins to include events in prejoin room state with
arbitrary state keys, be a good Matrix citizen and ensure they don't
accidentally create an oversized event.

* Changelog

* Move StateFilter tests

should have done this in #14668

* Add extra methods to StateFilter

* Use StateFilter

* Ensure test file enforces typed defs; alphabetise

* Workaround surprising get_current_state_ids

* Whoops, fix mypy
2022-12-13 00:54:46 +00:00
David Robertson
3d87847ecc Enable --warn-redundant-casts option in mypy (#14671)
* Enable `--warn-redundant-casts` option in mypy

Doesn't do much but helps me sleep better at night.

* Changelog

* Fix name of the ignore

* Fix one more missed cast

Not sure why I didn't see this one locally, maybe I needed a poetry update

* Remove old comment

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>

Co-authored-by: Patrick Cloke <clokep@users.noreply.github.com>
2022-12-12 21:25:07 +00:00
Sean Quah
7982891794 Fix missing cache invalidation in application service code (#14670)
#11915 introduced the `@cached` `is_interested_in_room` method in
Synapse 1.55.0, which depends upon `get_aliases_for_room`. Add a missing
cache invalidation callback so that the `is_interested_in_room` cache is
invalidated when `get_aliases_for_room` is invalidated.

#13787 made `get_rooms_for_user` `@cached`. Add a missing cache
invalidation callback so that the `is_interested_in_presence` cache is
invalidated when `get_rooms_for_user` is invalidated.

Signed-off-by: Sean Quah <seanq@matrix.org>
2022-12-12 18:13:43 +00:00
29 changed files with 2803 additions and 749 deletions

1114
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1 @@
Allow selecting "prejoin" events by state keys in addition to event types.

1
changelog.d/14670.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bugs introduced in 1.55.0 and 1.69.0 where application services would not be notified of events in the correct rooms, due to stale caches.

1
changelog.d/14671.misc Normal file
View File

@@ -0,0 +1 @@
Improve type hints.

View File

@@ -2501,32 +2501,53 @@ Config settings related to the client/server API
---
### `room_prejoin_state`
Controls for the state that is shared with users who receive an invite
to a room. By default, the following state event types are shared with users who
receive invites to the room:
- m.room.join_rules
- m.room.canonical_alias
- m.room.avatar
- m.room.encryption
- m.room.name
- m.room.create
- m.room.topic
This setting controls the state that is shared with users upon receiving an
invite to a room, or in reply to a knock on a room. By default, the following
state events are shared with users:
- `m.room.join_rules`
- `m.room.canonical_alias`
- `m.room.avatar`
- `m.room.encryption`
- `m.room.name`
- `m.room.create`
- `m.room.topic`
To change the default behavior, use the following sub-options:
* `disable_default_event_types`: set to true to disable the above defaults. If this
is enabled, only the event types listed in `additional_event_types` are shared.
Defaults to false.
* `additional_event_types`: Additional state event types to share with users when they are invited
to a room. By default, this list is empty (so only the default event types are shared).
* `disable_default_event_types`: boolean. Set to `true` to disable the above
defaults. If this is enabled, only the event types listed in
`additional_event_types` are shared. Defaults to `false`.
* `additional_event_types`: A list of additional state events to include in the
events to be shared. By default, this list is empty (so only the default event
types are shared).
Each entry in this list should be either a single string or a list of two
strings.
* A standalone string `t` represents all events with type `t` (i.e.
with no restrictions on state keys).
* A pair of strings `[t, s]` represents a single event with type `t` and
state key `s`. The same type can appear in two entries with different state
keys: in this situation, both state keys are included in prejoin state.
Example configuration:
```yaml
room_prejoin_state:
disable_default_event_types: true
disable_default_event_types: false
additional_event_types:
- org.example.custom.event.type
- m.room.join_rules
# Share all events of type `org.example.custom.event.typeA`
- org.example.custom.event.typeA
# Share only events of type `org.example.custom.event.typeB` whose
# state_key is "foo"
- ["org.example.custom.event.typeB", "foo"]
# Share only events of type `org.example.custom.event.typeC` whose
# state_key is "bar" or "baz"
- ["org.example.custom.event.typeC", "bar"]
- ["org.example.custom.event.typeC", "baz"]
```
*Changed in Synapse 1.74:* admins can filter the events in prejoin state based
on their state key.
---
### `track_puppeted_user_ips`

View File

@@ -12,6 +12,7 @@ local_partial_types = True
no_implicit_optional = True
disallow_untyped_defs = True
strict_equality = True
warn_redundant_casts = True
files =
docker/,
@@ -88,6 +89,12 @@ disallow_untyped_defs = False
[mypy-tests.*]
disallow_untyped_defs = False
[mypy-tests.config.test_api]
disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
disallow_untyped_defs = True
[mypy-tests.handlers.test_sso]
disallow_untyped_defs = True
@@ -100,7 +107,7 @@ disallow_untyped_defs = True
[mypy-tests.push.test_bulk_push_rule_evaluator]
disallow_untyped_defs = True
[mypy-tests.test_server]
[mypy-tests.rest.*]
disallow_untyped_defs = True
[mypy-tests.state.test_profile]
@@ -109,10 +116,10 @@ disallow_untyped_defs = True
[mypy-tests.storage.*]
disallow_untyped_defs = True
[mypy-tests.rest.*]
[mypy-tests.test_server]
disallow_untyped_defs = True
[mypy-tests.federation.transport.test_client]
[mypy-tests.types.*]
disallow_untyped_defs = True
[mypy-tests.util.caches.*]

View File

@@ -21,14 +21,25 @@ name = "synapse.synapse_rust"
[dependencies]
anyhow = "1.0.63"
env_logger = "0.10.0"
futures = "0.3.25"
futures-util = "0.3.25"
http = "0.2.8"
hyper = { version = "0.14.23", features = ["client", "http1", "http2", "runtime", "server", "full"] }
hyper-tls = "0.5.0"
lazy_static = "1.4.0"
log = "0.4.17"
native-tls = "0.2.11"
pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] }
pyo3-asyncio = { version = "0.17.0", features = ["tokio", "tokio-runtime"] }
pyo3-log = "0.7.0"
pythonize = "0.17.0"
regex = "1.6.0"
serde = { version = "1.0.144", features = ["derive"] }
serde_json = "1.0.85"
tokio = "1.23.0"
tokio-native-tls = "0.3.0"
trust-dns-resolver = "0.22.0"
[build-dependencies]
blake2 = "0.10.4"

158
rust/src/http/mod.rs Normal file
View File

@@ -0,0 +1,158 @@
use std::collections::HashMap;
use anyhow::Error;
use http::{Request, Uri};
use hyper::Body;
use log::info;
use pyo3::{
pyclass, pymethods,
types::{PyBytes, PyModule},
IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
};
use self::resolver::{MatrixConnector, MatrixResolver};
pub mod resolver;
/// Called when registering modules with python.
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let child_module = PyModule::new(py, "http")?;
child_module.add_class::<HttpClient>()?;
child_module.add_class::<MatrixResponse>()?;
m.add_submodule(child_module)?;
// We need to manually add the module to sys.modules to make `from
// synapse.synapse_rust import push` work.
py.import("sys")?
.getattr("modules")?
.set_item("synapse.synapse_rust.http", child_module)?;
Ok(())
}
#[derive(Clone, Debug)]
pub struct Bytes(pub Vec<u8>);
impl ToPyObject for Bytes {
fn to_object(&self, py: Python<'_>) -> pyo3::PyObject {
PyBytes::new(py, &self.0).into_py(py)
}
}
impl IntoPy<PyObject> for Bytes {
fn into_py(self, py: Python<'_>) -> PyObject {
self.to_object(py)
}
}
#[derive(Debug)]
#[pyclass]
pub struct MatrixResponse {
#[pyo3(get)]
pub code: u16,
#[pyo3(get)]
pub phrase: &'static str,
#[pyo3(get)]
pub content: Bytes,
#[pyo3(get)]
pub headers: HashMap<String, Bytes>,
}
#[pyclass]
#[derive(Clone)]
pub struct HttpClient {
client: hyper::Client<MatrixConnector>,
resolver: MatrixResolver,
}
impl HttpClient {
pub fn new() -> Result<Self, Error> {
let resolver = MatrixResolver::new()?;
let client =
hyper::Client::builder().build(MatrixConnector::with_resolver(resolver.clone()));
Ok(HttpClient { client, resolver })
}
pub async fn async_request(
&self,
url: String,
method: String,
headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
body: Option<Vec<u8>>,
) -> Result<MatrixResponse, Error> {
let uri: Uri = url.try_into()?;
let mut builder = Request::builder().method(&*method).uri(uri.clone());
for (key, values) in headers {
for value in values {
builder = builder.header(key.clone(), value);
}
}
if uri.scheme_str() == Some("matrix") {
let endpoints = self.resolver.resolve_server_name_from_uri(&uri).await?;
if let Some(endpoint) = endpoints.first() {
builder = builder.header("Host", &endpoint.host_header);
}
}
let request = if let Some(body) = body {
builder.body(Body::from(body))?
} else {
builder.body(Body::empty())?
};
let response = self.client.request(request).await?;
let code = response.status().as_u16();
let phrase = response.status().canonical_reason().unwrap_or_default();
let headers = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), Bytes(v.as_bytes().to_owned())))
.collect();
let body = response.into_body();
let bytes = hyper::body::to_bytes(body).await?;
let content = Bytes(bytes.to_vec());
Ok(MatrixResponse {
code,
phrase,
content,
headers,
})
}
}
#[pymethods]
impl HttpClient {
#[new]
fn py_new() -> Result<Self, Error> {
Self::new()
}
fn request<'a>(
&'a self,
py: Python<'a>,
url: String,
method: String,
headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
body: Option<Vec<u8>>,
) -> PyResult<&'a PyAny> {
pyo3::prepare_freethreaded_python();
let client = self.clone();
pyo3_asyncio::tokio::future_into_py(py, async move {
let resp = client.async_request(url, method, headers, body).await?;
Ok(resp)
})
}
}

432
rust/src/http/resolver.rs Normal file
View File

@@ -0,0 +1,432 @@
use std::collections::BTreeMap;
use std::future::Future;
use std::net::IpAddr;
use std::pin::Pin;
use std::str::FromStr;
use std::{
io::Cursor,
sync::{Arc, Mutex},
task::{self, Poll},
};
use anyhow::{bail, Error};
use futures::{FutureExt, TryFutureExt};
use futures_util::stream::StreamExt;
use http::Uri;
use hyper::client::connect::Connection;
use hyper::client::connect::{Connected, HttpConnector};
use hyper::server::conn::Http;
use hyper::service::Service;
use hyper::Client;
use hyper_tls::HttpsConnector;
use hyper_tls::MaybeHttpsStream;
use log::{debug, info};
use native_tls::TlsConnector;
use serde::Deserialize;
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use tokio::net::TcpStream;
use tokio_native_tls::TlsConnector as AsyncTlsConnector;
use trust_dns_resolver::error::ResolveErrorKind;
#[derive(Debug, Clone)]
pub struct Endpoint {
pub host: String,
pub port: u16,
pub host_header: String,
pub tls_name: String,
}
#[derive(Clone)]
pub struct MatrixResolver {
resolver: trust_dns_resolver::TokioAsyncResolver,
http_client: Client<HttpsConnector<HttpConnector>>,
}
impl MatrixResolver {
pub fn new() -> Result<MatrixResolver, Error> {
let http_client = hyper::Client::builder().build(HttpsConnector::new());
MatrixResolver::with_client(http_client)
}
pub fn with_client(
http_client: Client<HttpsConnector<HttpConnector>>,
) -> Result<MatrixResolver, Error> {
let resolver = trust_dns_resolver::TokioAsyncResolver::tokio_from_system_conf()?;
Ok(MatrixResolver {
resolver,
http_client,
})
}
/// Does SRV lookup
pub async fn resolve_server_name_from_uri(&self, uri: &Uri) -> Result<Vec<Endpoint>, Error> {
let host = uri.host().expect("URI has no host").to_string();
let port = uri.port_u16();
self.resolve_server_name_from_host_port(host, port).await
}
pub async fn resolve_server_name_from_host_port(
&self,
mut host: String,
mut port: Option<u16>,
) -> Result<Vec<Endpoint>, Error> {
let mut authority = if let Some(p) = port {
format!("{}:{}", host, p)
} else {
host.to_string()
};
// If a literal IP or includes port then we shortcircuit.
if host.parse::<IpAddr>().is_ok() || port.is_some() {
return Ok(vec![Endpoint {
host: host.to_string(),
port: port.unwrap_or(8448),
host_header: authority.to_string(),
tls_name: host.to_string(),
}]);
}
// Do well-known delegation lookup.
if let Some(server) = get_well_known(&self.http_client, &host).await {
let a = http::uri::Authority::from_str(&server.server)?;
host = a.host().to_string();
port = a.port_u16();
authority = a.to_string();
}
// If a literal IP or includes port then we shortcircuit.
if host.parse::<IpAddr>().is_ok() || port.is_some() {
return Ok(vec![Endpoint {
host: host.clone(),
port: port.unwrap_or(8448),
host_header: authority.to_string(),
tls_name: host.clone(),
}]);
}
let result = self
.resolver
.srv_lookup(format!("_matrix._tcp.{}", host))
.await;
let records = match result {
Ok(records) => records,
Err(err) => match err.kind() {
ResolveErrorKind::NoRecordsFound { .. } => {
return Ok(vec![Endpoint {
host: host.clone(),
port: 8448,
host_header: authority.to_string(),
tls_name: host.clone(),
}])
}
_ => return Err(err.into()),
},
};
let mut priority_map: BTreeMap<u16, Vec<_>> = BTreeMap::new();
let mut count = 0;
for record in records {
count += 1;
let priority = record.priority();
priority_map.entry(priority).or_default().push(record);
}
let mut results = Vec::with_capacity(count);
for (_priority, records) in priority_map {
// TODO: Correctly shuffle records
results.extend(records.into_iter().map(|record| Endpoint {
host: record.target().to_utf8(),
port: record.port(),
host_header: host.to_string(),
tls_name: host.to_string(),
}))
}
Ok(results)
}
}
async fn get_well_known<C>(http_client: &Client<C>, host: &str) -> Option<WellKnownServer>
where
C: Service<Uri> + Clone + Sync + Send + 'static,
C::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
C::Future: Unpin + Send,
C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
{
// TODO: Add timeout.
let uri = hyper::Uri::builder()
.scheme("https")
.authority(host)
.path_and_query("/.well-known/matrix/server")
.build()
.ok()?;
let mut body = http_client.get(uri).await.ok()?.into_body();
let mut vec = Vec::new();
while let Some(next) = body.next().await {
let chunk = next.ok()?;
vec.extend(chunk);
}
serde_json::from_slice(&vec).ok()?
}
#[derive(Deserialize)]
struct WellKnownServer {
#[serde(rename = "m.server")]
server: String,
}
#[derive(Clone)]
pub struct MatrixConnector {
resolver: MatrixResolver,
}
impl MatrixConnector {
pub fn with_resolver(resolver: MatrixResolver) -> MatrixConnector {
MatrixConnector { resolver }
}
}
impl Service<Uri> for MatrixConnector {
type Response = MaybeHttpsStream<TcpStream>;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// This connector is always ready, but others might not be.
Poll::Ready(Ok(()))
}
fn call(&mut self, dst: Uri) -> Self::Future {
let resolver = self.resolver.clone();
if dst.scheme_str() != Some("matrix") {
debug!("Got non-matrix scheme");
return HttpsConnector::new()
.call(dst)
.map_err(|e| Error::msg(e))
.boxed();
}
async move {
let endpoints = resolver
.resolve_server_name_from_host_port(
dst.host().expect("hostname").to_string(),
dst.port_u16(),
)
.await?;
debug!("Got endpoints: {:?}", endpoints);
for endpoint in endpoints {
match try_connecting(&dst, &endpoint).await {
Ok(r) => return Ok(r),
// Errors here are not unexpected, and we just move on
// with our lives.
Err(e) => info!(
"Failed to connect to {} via {}:{} because {}",
dst.host().expect("hostname"),
endpoint.host,
endpoint.port,
e,
),
}
}
bail!(
"failed to resolve host: {:?} port {:?}",
dst.host(),
dst.port()
)
}
.boxed()
}
}
/// Attempts to connect to a particular endpoint.
async fn try_connecting(
dst: &Uri,
endpoint: &Endpoint,
) -> Result<MaybeHttpsStream<TcpStream>, Error> {
let tcp = TcpStream::connect((&endpoint.host as &str, endpoint.port)).await?;
let connector: AsyncTlsConnector = if dst.host().expect("hostname").contains("localhost") {
TlsConnector::builder()
.danger_accept_invalid_certs(true)
.build()?
.into()
} else {
TlsConnector::new().unwrap().into()
};
let tls = connector.connect(&endpoint.tls_name, tcp).await?;
Ok(tls.into())
}
/// A connector that reutrns a connection which returns 200 OK to all connections.
#[derive(Clone)]
pub struct TestConnector;
impl Service<Uri> for TestConnector {
type Response = TestConnection;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
// This connector is always ready, but others might not be.
Poll::Ready(Ok(()))
}
fn call(&mut self, _dst: Uri) -> Self::Future {
let (client, server) = TestConnection::double_ended();
{
let service = hyper::service::service_fn(|_| async move {
Ok(hyper::Response::new(hyper::Body::from("Hello World")))
as Result<_, hyper::http::Error>
});
let fut = Http::new().serve_connection(server, service);
tokio::spawn(fut);
}
futures::future::ok(client).boxed()
}
}
#[derive(Default)]
struct TestConnectionInner {
outbound_buffer: Cursor<Vec<u8>>,
inbound_buffer: Cursor<Vec<u8>>,
wakers: Vec<futures::task::Waker>,
}
/// A in memory connection for use with tests.
#[derive(Clone, Default)]
pub struct TestConnection {
inner: Arc<Mutex<TestConnectionInner>>,
direction: bool,
}
impl TestConnection {
pub fn double_ended() -> (TestConnection, TestConnection) {
let inner: Arc<Mutex<TestConnectionInner>> = Arc::default();
let a = TestConnection {
inner: inner.clone(),
direction: false,
};
let b = TestConnection {
inner,
direction: true,
};
(a, b)
}
}
impl AsyncRead for TestConnection {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<Result<(), std::io::Error>> {
let mut conn = self.inner.lock().expect("mutex");
let buffer = if self.direction {
&mut conn.inbound_buffer
} else {
&mut conn.outbound_buffer
};
let bytes_read = std::io::Read::read(buffer, buf.initialize_unfilled())?;
buf.advance(bytes_read);
if bytes_read > 0 {
Poll::Ready(Ok(()))
} else {
conn.wakers.push(cx.waker().clone());
Poll::Pending
}
}
}
impl AsyncWrite for TestConnection {
fn poll_write(
self: Pin<&mut Self>,
_cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
let mut conn = self.inner.lock().expect("mutex");
if self.direction {
conn.outbound_buffer.get_mut().extend_from_slice(buf);
} else {
conn.inbound_buffer.get_mut().extend_from_slice(buf);
}
for waker in conn.wakers.drain(..) {
waker.wake()
}
Poll::Ready(Ok(buf.len()))
}
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let mut conn = self.inner.lock().expect("mutex");
if self.direction {
Pin::new(&mut conn.outbound_buffer).poll_flush(cx)
} else {
Pin::new(&mut conn.inbound_buffer).poll_flush(cx)
}
}
fn poll_shutdown(
self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
) -> Poll<Result<(), std::io::Error>> {
let mut conn = self.inner.lock().expect("mutex");
if self.direction {
Pin::new(&mut conn.outbound_buffer).poll_shutdown(cx)
} else {
Pin::new(&mut conn.inbound_buffer).poll_shutdown(cx)
}
}
}
impl Connection for TestConnection {
fn connected(&self) -> Connected {
Connected::new()
}
}
#[tokio::test]
async fn test_memory_connection() {
let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(TestConnector);
let response = client
.get("http://localhost".parse().unwrap())
.await
.unwrap();
assert!(response.status().is_success());
let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();
assert_eq!(&bytes[..], b"Hello World");
}

View File

@@ -1,5 +1,6 @@
use pyo3::prelude::*;
pub mod http;
pub mod push;
/// Returns the hash of all the rust source files at the time it was compiled.
@@ -26,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
push::register_module(py, m)?;
http::register_module(py, m)?;
Ok(())
}

View File

@@ -27,7 +27,7 @@ import time
import urllib.request
from os import path
from tempfile import TemporaryDirectory
from typing import Any, List, Optional, cast
from typing import Any, List, Optional
import attr
import click
@@ -174,9 +174,7 @@ def _prepare() -> None:
click.get_current_context().abort()
# Switch to the release branch.
# Cast safety: parse() won't return a version.LegacyVersion from our
# version string format.
parsed_new_version = cast(version.Version, version.parse(new_version))
parsed_new_version = version.parse(new_version)
# We assume for debian changelogs that we only do RCs or full releases.
assert not parsed_new_version.is_devrelease

View File

@@ -0,0 +1,16 @@
from typing import Dict, List, Optional
class MatrixResponse:
code: int
phrase: str
content: bytes
headers: Dict[str, str]
class HttpClient:
async def request(
self,
url: str,
method: str,
headers: Dict[bytes, List[bytes]],
body: Optional[bytes],
) -> MatrixResponse: ...

View File

@@ -29,7 +29,7 @@ if sys.version_info < (3, 7):
sys.exit(1)
# Allow using the asyncio reactor via env var.
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")):
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")) or True:
from incremental import Version
import twisted

View File

@@ -245,7 +245,9 @@ class ApplicationService:
return True
# likewise with the room's aliases (if it has any)
alias_list = await store.get_aliases_for_room(room_id)
alias_list = await store.get_aliases_for_room(
room_id, on_invalidate=cache_context.invalidate
)
for alias in alias_list:
if self.is_room_alias_in_namespace(alias):
return True
@@ -311,7 +313,9 @@ class ApplicationService:
# Find all the rooms the sender is in
if self.is_interested_in_user(user_id.to_string()):
return True
room_ids = await store.get_rooms_for_user(user_id.to_string())
room_ids = await store.get_rooms_for_user(
user_id.to_string(), on_invalidate=cache_context.invalidate
)
# Then find out if the appservice is interested in any of those rooms
for room_id in room_ids:

View File

@@ -33,6 +33,9 @@ def validate_config(
config: the configuration value to be validated
config_path: the path within the config file. This will be used as a basis
for the error message.
Raises:
ConfigError, if validation fails.
"""
try:
jsonschema.validate(config, json_schema)

View File

@@ -13,12 +13,13 @@
# limitations under the License.
import logging
from typing import Any, Iterable
from typing import Any, Iterable, Optional, Tuple
from synapse.api.constants import EventTypes
from synapse.config._base import Config, ConfigError
from synapse.config._util import validate_config
from synapse.types import JsonDict
from synapse.types.state import StateFilter
logger = logging.getLogger(__name__)
@@ -26,16 +27,20 @@ logger = logging.getLogger(__name__)
class ApiConfig(Config):
section = "api"
room_prejoin_state: StateFilter
track_puppetted_users_ips: bool
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
validate_config(_MAIN_SCHEMA, config, ())
self.room_prejoin_state = list(self._get_prejoin_state_types(config))
self.room_prejoin_state = StateFilter.from_types(
self._get_prejoin_state_entries(config)
)
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
"""Get the event types to include in the prejoin state
Parses the config and returns an iterable of the event types to be included.
"""
def _get_prejoin_state_entries(
self, config: JsonDict
) -> Iterable[Tuple[str, Optional[str]]]:
"""Get the event types and state keys to include in the prejoin state."""
room_prejoin_state_config = config.get("room_prejoin_state") or {}
# backwards-compatibility support for room_invite_state_types
@@ -50,33 +55,39 @@ class ApiConfig(Config):
logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
yield from config["room_invite_state_types"]
for event_type in config["room_invite_state_types"]:
yield event_type, None
return
if not room_prejoin_state_config.get("disable_default_event_types"):
yield from _DEFAULT_PREJOIN_STATE_TYPES
yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS
yield from room_prejoin_state_config.get("additional_event_types", [])
for entry in room_prejoin_state_config.get("additional_event_types", []):
if isinstance(entry, str):
yield entry, None
else:
yield entry
_ROOM_INVITE_STATE_TYPES_WARNING = """\
WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
and replaced with 'room_prejoin_state'. New features may not work correctly
unless 'room_invite_state_types' is removed. See the sample configuration file for
details of 'room_prejoin_state'.
unless 'room_invite_state_types' is removed. See the config documentation at
https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state
for details of 'room_prejoin_state'.
--------------------------------------------------------------------------------
"""
_DEFAULT_PREJOIN_STATE_TYPES = [
EventTypes.JoinRules,
EventTypes.CanonicalAlias,
EventTypes.RoomAvatar,
EventTypes.RoomEncryption,
EventTypes.Name,
_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [
(EventTypes.JoinRules, ""),
(EventTypes.CanonicalAlias, ""),
(EventTypes.RoomAvatar, ""),
(EventTypes.RoomEncryption, ""),
(EventTypes.Name, ""),
# Per MSC1772.
EventTypes.Create,
(EventTypes.Create, ""),
# Per MSC3173.
EventTypes.Topic,
(EventTypes.Topic, ""),
]
@@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
"disable_default_event_types": {"type": "boolean"},
"additional_event_types": {
"type": "array",
"items": {"type": "string"},
"items": {
"oneOf": [
{"type": "string"},
{
"type": "array",
"items": {"type": "string"},
"minItems": 2,
"maxItems": 2,
},
],
},
},
},
},

View File

@@ -28,8 +28,14 @@ from typing import (
)
import attr
from canonicaljson import encode_canonical_json
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.constants import (
MAX_PDU_SIZE,
EventContentFields,
EventTypes,
RelationTypes,
)
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
@@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None:
elif not isinstance(value, (bool, str)) and value is not None:
# Other potential JSON values (bool, None, str) are safe.
raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
def maybe_upsert_event_field(
event: EventBase, container: JsonDict, key: str, value: object
) -> bool:
"""Upsert an event field, but only if this doesn't make the event too large.
Returns true iff the upsert took place.
"""
if key in container:
old_value: object = container[key]
container[key] = value
# NB: here and below, we assume that passing a non-None `time_now` argument to
# get_pdu_json doesn't increase the size of the encoded result.
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
if not upsert_okay:
container[key] = old_value
else:
container[key] = value
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
if not upsert_okay:
del container[key]
return upsert_okay

View File

@@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
from synapse.events import EventBase, relation_from_event
from synapse.events.builder import EventBuilder
from synapse.events.snapshot import EventContext
from synapse.events.utils import maybe_upsert_event_field
from synapse.events.validator import EventValidator
from synapse.handlers.directory import DirectoryHandler
from synapse.logging import opentracing
@@ -1739,12 +1740,15 @@ class EventCreationHandler:
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.INVITE:
event.unsigned[
"invite_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
membership_user_id=event.sender,
maybe_upsert_event_field(
event,
event.unsigned,
"invite_room_state",
await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
membership_user_id=event.sender,
),
)
invitee = UserID.from_string(event.state_key)
@@ -1762,11 +1766,14 @@ class EventCreationHandler:
event.signatures.update(returned_invite.signatures)
if event.content["membership"] == Membership.KNOCK:
event.unsigned[
"knock_room_state"
] = await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
maybe_upsert_event_field(
event,
event.unsigned,
"knock_room_state",
await self.store.get_stripped_room_state_from_event_context(
context,
self.room_prejoin_state_types,
),
)
if event.type == EventTypes.Redaction:

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import asyncio
import cgi
import codecs
import logging
@@ -42,14 +43,18 @@ from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
from typing_extensions import Literal
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.internet.testing import StringTransport
from twisted.python.failure import Failure
from twisted.web.client import Response, ResponseDone, ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
@@ -75,6 +80,7 @@ from synapse.http.types import QueryParams
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.synapse_rust.http import HttpClient
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
@@ -199,6 +205,33 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
return json_decoder.decode(self._buffer.getvalue())
@attr.s(auto_attribs=True)
@implementer(IResponse)
class RustResponse:
version: tuple
code: int
phrase: bytes
headers: Headers
length: Union[int, UNKNOWN_LENGTH]
# request: Optional[IClientRequest]
# previousResponse: Optional[IResponse]
_data: bytes
def deliverBody(self, protocol: Protocol):
protocol.dataReceived(self._data)
protocol.connectionLost(Failure(ResponseDone("Response body fully received")))
def setPreviousResponse(self, response: IResponse):
pass
async def _handle_response(
reactor: IReactorTime,
timeout_sec: float,
@@ -372,6 +405,8 @@ class MatrixFederationHttpClient:
self._sleeper = AwakenableSleeper(self.reactor)
self._rust_client = HttpClient()
def wake_destination(self, destination: str) -> None:
"""Called when the remote server may have come back online."""
@@ -556,11 +591,8 @@ class MatrixFederationHttpClient:
destination_bytes, method_bytes, url_to_sign_bytes, json
)
data = encode_canonical_json(json)
producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
BytesIO(data), cooperator=self._cooperator
)
else:
producer = None
data = None
auth_headers = self.build_auth_headers(
destination_bytes, method_bytes, url_to_sign_bytes
)
@@ -591,23 +623,33 @@ class MatrixFederationHttpClient:
# * The `Deferred` that joins the forks back together is
# wrapped in `make_deferred_yieldable` to restore the
# logging context regardless of the path taken.
request_deferred = run_in_background(
self.agent.request,
method_bytes,
url_bytes,
headers=Headers(headers_dict),
bodyProducer=producer,
)
request_deferred = timeout_deferred(
request_deferred,
timeout=_sec_timeout,
reactor=self.reactor,
)
# request_deferred = run_in_background(
# self._rust_client.request,
# url_str,
# request.method,
# headers_dict,
# data,
# )
# request_deferred = timeout_deferred(
# request_deferred,
# timeout=_sec_timeout,
# reactor=self.reactor,
# )
response = await make_deferred_yieldable(request_deferred)
# response = await make_deferred_yieldable(request_deferred)
response_d = run_in_background(
self._rust_client.request,
url_str,
request.method,
headers_dict,
data,
)
response = await make_deferred_yieldable(response_d)
except DNSLookupError as e:
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
except Exception as e:
logger.exception("ERROR")
raise RequestSendFailed(e, can_retry=True) from e
incoming_responses_counter.labels(
@@ -615,7 +657,7 @@ class MatrixFederationHttpClient:
).inc()
set_tag(tags.HTTP_STATUS_CODE, response.code)
response_phrase = response.phrase.decode("ascii", errors="replace")
response_phrase = response.phrase
if 200 <= response.code < 300:
logger.debug(
@@ -635,25 +677,7 @@ class MatrixFederationHttpClient:
)
# :'(
# Update transactions table?
d = treq.content(response)
d = timeout_deferred(
d, timeout=_sec_timeout, reactor=self.reactor
)
try:
body = await make_deferred_yieldable(d)
except Exception as e:
# Eh, we're already going to raise an exception so lets
# ignore if this fails.
logger.warning(
"{%s} [%s] Failed to get error response: %s %s: %s",
request.txn_id,
request.destination,
request.method,
url_str,
_flatten_response_never_received(e),
)
body = None
body = response.content
exc = HttpResponseException(
response.code, response_phrase, body
@@ -715,7 +739,19 @@ class MatrixFederationHttpClient:
_flatten_response_never_received(e),
)
raise
return response
headers = Headers()
for key, value in response.headers.items():
headers.addRawHeader(key, value)
return RustResponse(
("HTTP", 1, 1),
response.code,
response.phrase.encode("ascii"),
headers,
UNKNOWN_LENGTH,
response.content,
)
def build_auth_headers(
self,

View File

@@ -26,6 +26,7 @@ import logging
import threading
import typing
import warnings
from asyncio import Future
from types import TracebackType
from typing import (
TYPE_CHECKING,
@@ -814,6 +815,8 @@ def run_in_background( # type: ignore[misc]
res = defer.ensureDeferred(res)
elif isinstance(res, defer.Deferred):
pass
elif isinstance(res, Future):
res = defer.Deferred.fromFuture(res)
elif isinstance(res, Awaitable):
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
# or `Future` from `make_awaitable`.

View File

@@ -667,7 +667,8 @@ class DatabasePool:
)
# also check variables referenced in func's closure
if inspect.isfunction(func):
f = cast(types.FunctionType, func)
# Keep the cast for now---it helps PyCharm to understand what `func` is.
f = cast(types.FunctionType, func) # type: ignore[redundant-cast]
if f.__closure__:
for i, cell in enumerate(f.__closure__):
if inspect.isgenerator(cell.cell_contents):

View File

@@ -16,11 +16,11 @@ import logging
import threading
import weakref
from enum import Enum, auto
from itertools import chain
from typing import (
TYPE_CHECKING,
Any,
Collection,
Container,
Dict,
Iterable,
List,
@@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import (
)
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
@@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
async def get_stripped_room_state_from_event_context(
self,
context: EventContext,
state_types_to_include: Container[str],
state_keys_to_include: StateFilter,
membership_user_id: Optional[str] = None,
) -> List[JsonDict]:
"""
@@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
Args:
context: The event context to retrieve state of the room from.
state_types_to_include: The type of state events to include.
state_keys_to_include: The state events to include, for each event type.
membership_user_id: An optional user ID to include the stripped membership state
events of. This is useful when generating the stripped state of a room for
invites. We want to send membership events of the inviter, so that the
@@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
Returns:
A list of dictionaries, each representing a stripped state event from the room.
"""
current_state_ids = await context.get_current_state_ids()
if membership_user_id:
types = chain(
state_keys_to_include.to_types(),
[(EventTypes.Member, membership_user_id)],
)
filter = StateFilter.from_types(types)
else:
filter = state_keys_to_include
selected_state_ids = await context.get_current_state_ids(filter)
# We know this event is not an outlier, so this must be
# non-None.
assert current_state_ids is not None
assert selected_state_ids is not None
# The state to include
state_to_include_ids = [
e_id
for k, e_id in current_state_ids.items()
if k[0] in state_types_to_include
or (membership_user_id and k == (EventTypes.Member, membership_user_id))
]
# Confusingly, get_current_state_events may return events that are discarded by
# the filter, if they're in context._state_delta_due_to_event. Strip these away.
selected_state_ids = filter.filter_state(selected_state_ids)
state_to_include = await self.get_events(state_to_include_ids)
state_to_include = await self.get_events(selected_state_ids.values())
return [
{

View File

@@ -77,7 +77,7 @@ class PostgresEngine(
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = cast(int, db_conn.server_version)
self._version = db_conn.server_version
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
# Are we on a supported PostgreSQL version?

View File

@@ -118,6 +118,15 @@ class StateFilter:
)
)
def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
"""The inverse to `from_types`."""
for (event_type, state_keys) in self.types.items():
if state_keys is None:
yield event_type, None
else:
for state_key in state_keys:
yield event_type, state_key
@staticmethod
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member
@@ -343,6 +352,15 @@ class StateFilter:
for s in state_keys
]
def wildcard_types(self) -> List[str]:
"""Returns a list of event types which require us to fetch all state keys.
This will be empty unless `has_wildcards` returns True.
Returns:
A list of event types.
"""
return [t for t, state_keys in self.types.items() if state_keys is None]
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching

145
tests/config/test_api.py Normal file
View File

@@ -0,0 +1,145 @@
from unittest import TestCase as StdlibTestCase
import yaml
from synapse.config import ConfigError
from synapse.config.api import ApiConfig
from synapse.types.state import StateFilter
DEFAULT_PREJOIN_STATE_PAIRS = {
("m.room.join_rules", ""),
("m.room.canonical_alias", ""),
("m.room.avatar", ""),
("m.room.encryption", ""),
("m.room.name", ""),
("m.room.create", ""),
("m.room.topic", ""),
}
class TestRoomPrejoinState(StdlibTestCase):
def read_config(self, source: str) -> ApiConfig:
config = ApiConfig()
config.read_config(yaml.safe_load(source))
return config
def test_no_prejoin_state(self) -> None:
config = self.read_config("foo: bar")
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
)
def test_disable_default_event_types(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
"""
)
self.assertEqual(config.room_prejoin_state, StateFilter.none())
def test_event_without_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- foo
"""
)
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
def test_event_with_specific_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- [foo, bar]
"""
)
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()),
{("foo", "bar")},
)
def test_repeated_event_with_specific_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- [foo, bar]
- [foo, baz]
"""
)
self.assertFalse(config.room_prejoin_state.has_wildcards())
self.assertEqual(
set(config.room_prejoin_state.concrete_types()),
{("foo", "bar"), ("foo", "baz")},
)
def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- [foo, bar]
- foo
"""
)
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
config = self.read_config(
"""
room_prejoin_state:
disable_default_event_types: true
additional_event_types:
- foo
- [foo, bar]
"""
)
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
def test_bad_event_type_entry_raises(self) -> None:
with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- []
"""
)
with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- [a]
"""
)
with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- [a, b, c]
"""
)
with self.assertRaises(ConfigError):
self.read_config(
"""
room_prejoin_state:
additional_event_types:
- [true, 1.23]
"""
)

View File

@@ -12,19 +12,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest as stdlib_unittest
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import RoomVersions
from synapse.events import make_event_from_dict
from synapse.events.utils import (
SerializeEventConfig,
copy_and_fixup_power_levels_contents,
maybe_upsert_event_field,
prune_event,
serialize_event,
)
from synapse.util.frozenutils import freeze
from tests import unittest
def MockEvent(**kwargs):
if "event_id" not in kwargs:
@@ -34,7 +35,31 @@ def MockEvent(**kwargs):
return make_event_from_dict(kwargs)
class PruneEventTestCase(unittest.TestCase):
class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
def test_update_okay(self) -> None:
event = make_event_from_dict({"event_id": "$1234"})
success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
self.assertTrue(success)
self.assertEqual(event.unsigned["key"], "value")
def test_update_not_okay(self) -> None:
event = make_event_from_dict({"event_id": "$1234"})
LARGE_STRING = "a" * 100_000
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
self.assertFalse(success)
self.assertNotIn("key", event.unsigned)
def test_update_not_okay_leaves_original_value(self) -> None:
event = make_event_from_dict(
{"event_id": "$1234", "unsigned": {"key": "value"}}
)
LARGE_STRING = "a" * 100_000
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
self.assertFalse(success)
self.assertEqual(event.unsigned["key"], "value")
class PruneEventTestCase(stdlib_unittest.TestCase):
def run_test(self, evdict, matchdict, **kwargs):
"""
Asserts that a new event constructed with `evdict` will look like
@@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase):
)
class SerializeEventTestCase(unittest.TestCase):
class SerializeEventTestCase(stdlib_unittest.TestCase):
def serialize(self, ev, fields):
return serialize_event(
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
@@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase):
)
class CopyPowerLevelsContentTestCase(unittest.TestCase):
class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
def setUp(self) -> None:
self.test_content = {
"ban": 50,

View File

@@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID
from synapse.types.state import StateFilter
from synapse.util import Clock
from tests.unittest import HomeserverTestCase, TestCase
from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
@@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase):
self.assertEqual(is_all, True)
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
def test_state_filter_difference_no_include_other_minus_no_include_other(
self,
) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=False,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.CanonicalAlias: {""}},
include_others=False,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Create: None,
EventTypes.Member: set(),
EventTypes.CanonicalAlias: set(),
},
include_others=True,
),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
# This also shows that the resultant state filter is normalised.
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
StateFilter(types=frozendict(), include_others=True),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter(
types=frozendict(),
include_others=True,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.CanonicalAlias: {""},
EventTypes.Member: set(),
},
include_others=True,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
EventTypes.Create: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=True,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
"""
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
self.assert_difference(
StateFilter.all(),
StateFilter.none(),
StateFilter.all(),
)
class StateFilterTestCase(TestCase):
def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
"""
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
# Concrete-only state filters stay the same
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
),
)
# Concrete-only state filters stay the same
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{"some.other.state.type": {""}}, include_others=False
).return_expanded(),
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
)
# Concrete-only state filters stay the same
# (Case: member-only filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
),
)
# Wildcard member-only state filters stay the same
self.assertEqual(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
include_others=True,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
"yet.another.state.type": {"wombat"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)

0
tests/types/__init__.py Normal file
View File

627
tests/types/test_state.py Normal file
View File

@@ -0,0 +1,627 @@
from frozendict import frozendict
from synapse.api.constants import EventTypes
from synapse.types.state import StateFilter
from tests.unittest import TestCase
class StateFilterDifferenceTestCase(TestCase):
def assert_difference(
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
) -> None:
self.assertEqual(
minuend.approx_difference(subtrahend),
expected,
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
)
def test_state_filter_difference_no_include_other_minus_no_include_other(
self,
) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b do not have the
include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=False,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.CanonicalAlias: {""}},
include_others=False,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
)
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only a has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Create: None,
EventTypes.Member: set(),
EventTypes.CanonicalAlias: set(),
},
include_others=True,
),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
# This also shows that the resultant state filter is normalised.
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
StateFilter(types=frozendict(), include_others=True),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter(
types=frozendict(),
include_others=True,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.CanonicalAlias: {""},
EventTypes.Member: set(),
},
include_others=True,
),
)
# (specific state keys) - (specific state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
)
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), both a and b have the include_others
flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=False,
),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
EventTypes.Create: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
EventTypes.Create: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
EventTypes.Create: {""},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
"""
Tests the StateFilter.approx_difference method
where, in a.approx_difference(b), only b has the include_others flag set.
"""
# (wildcard on state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.Create: None},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
include_others=True,
),
StateFilter(types=frozendict(), include_others=False),
)
# (wildcard on state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:spqr"}},
include_others=True,
),
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
)
# (wildcard on state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# (specific state keys) - (wildcard on state keys):
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=True,
),
StateFilter(
types=frozendict(),
include_others=False,
),
)
# (specific state keys) - (specific state keys)
# This one is an over-approximation because we can't represent
# 'all state keys except a few named examples'
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr"},
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@spqr:spqr"},
},
include_others=False,
),
)
# (specific state keys) - (no state keys)
self.assert_difference(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
EventTypes.CanonicalAlias: {""},
},
include_others=False,
),
StateFilter.freeze(
{
EventTypes.Member: set(),
},
include_others=True,
),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
},
include_others=False,
),
)
def test_state_filter_difference_simple_cases(self) -> None:
"""
Tests some very simple cases of the StateFilter approx_difference,
that are not explicitly tested by the more in-depth tests.
"""
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
self.assert_difference(
StateFilter.all(),
StateFilter.none(),
StateFilter.all(),
)
class StateFilterTestCase(TestCase):
def test_return_expanded(self) -> None:
"""
Tests the behaviour of the return_expanded() function that expands
StateFilters to include more state types (for the sake of cache hit rate).
"""
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
# Concrete-only state filters stay the same
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": {""},
},
include_others=False,
),
)
# Concrete-only state filters stay the same
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{"some.other.state.type": {""}}, include_others=False
).return_expanded(),
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
)
# Concrete-only state filters stay the same
# (Case: member-only filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
},
include_others=False,
),
)
# Wildcard member-only state filters stay the same
self.assertEqual(
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: None},
include_others=False,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: mixed filter)
self.assertEqual(
StateFilter.freeze(
{
EventTypes.Member: {"@wombat:test", "@alicia:test"},
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze(
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
include_others=True,
),
)
# If there is a wildcard in the non-member portion of the filter,
# it's expanded to include ALL non-member events.
# (Case: non-member-only filter)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)
self.assertEqual(
StateFilter.freeze(
{
"some.other.state.type": None,
"yet.another.state.type": {"wombat"},
},
include_others=False,
).return_expanded(),
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
)