Compare commits
7 Commits
mv/unbind-
...
erikj/rust
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f5817281f8 | ||
|
|
87406aa5d3 | ||
|
|
6842974391 | ||
|
|
c93ef61fa3 | ||
|
|
e2a1adbf5d | ||
|
|
3d87847ecc | ||
|
|
7982891794 |
1114
Cargo.lock
generated
1114
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
1
changelog.d/14642.feature
Normal file
1
changelog.d/14642.feature
Normal file
@@ -0,0 +1 @@
|
||||
Allow selecting "prejoin" events by state keys in addition to event types.
|
||||
1
changelog.d/14670.bugfix
Normal file
1
changelog.d/14670.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix bugs introduced in 1.55.0 and 1.69.0 where application services would not be notified of events in the correct rooms, due to stale caches.
|
||||
1
changelog.d/14671.misc
Normal file
1
changelog.d/14671.misc
Normal file
@@ -0,0 +1 @@
|
||||
Improve type hints.
|
||||
@@ -2501,32 +2501,53 @@ Config settings related to the client/server API
|
||||
---
|
||||
### `room_prejoin_state`
|
||||
|
||||
Controls for the state that is shared with users who receive an invite
|
||||
to a room. By default, the following state event types are shared with users who
|
||||
receive invites to the room:
|
||||
- m.room.join_rules
|
||||
- m.room.canonical_alias
|
||||
- m.room.avatar
|
||||
- m.room.encryption
|
||||
- m.room.name
|
||||
- m.room.create
|
||||
- m.room.topic
|
||||
This setting controls the state that is shared with users upon receiving an
|
||||
invite to a room, or in reply to a knock on a room. By default, the following
|
||||
state events are shared with users:
|
||||
|
||||
- `m.room.join_rules`
|
||||
- `m.room.canonical_alias`
|
||||
- `m.room.avatar`
|
||||
- `m.room.encryption`
|
||||
- `m.room.name`
|
||||
- `m.room.create`
|
||||
- `m.room.topic`
|
||||
|
||||
To change the default behavior, use the following sub-options:
|
||||
* `disable_default_event_types`: set to true to disable the above defaults. If this
|
||||
is enabled, only the event types listed in `additional_event_types` are shared.
|
||||
Defaults to false.
|
||||
* `additional_event_types`: Additional state event types to share with users when they are invited
|
||||
to a room. By default, this list is empty (so only the default event types are shared).
|
||||
* `disable_default_event_types`: boolean. Set to `true` to disable the above
|
||||
defaults. If this is enabled, only the event types listed in
|
||||
`additional_event_types` are shared. Defaults to `false`.
|
||||
* `additional_event_types`: A list of additional state events to include in the
|
||||
events to be shared. By default, this list is empty (so only the default event
|
||||
types are shared).
|
||||
|
||||
Each entry in this list should be either a single string or a list of two
|
||||
strings.
|
||||
* A standalone string `t` represents all events with type `t` (i.e.
|
||||
with no restrictions on state keys).
|
||||
* A pair of strings `[t, s]` represents a single event with type `t` and
|
||||
state key `s`. The same type can appear in two entries with different state
|
||||
keys: in this situation, both state keys are included in prejoin state.
|
||||
|
||||
Example configuration:
|
||||
```yaml
|
||||
room_prejoin_state:
|
||||
disable_default_event_types: true
|
||||
disable_default_event_types: false
|
||||
additional_event_types:
|
||||
- org.example.custom.event.type
|
||||
- m.room.join_rules
|
||||
# Share all events of type `org.example.custom.event.typeA`
|
||||
- org.example.custom.event.typeA
|
||||
# Share only events of type `org.example.custom.event.typeB` whose
|
||||
# state_key is "foo"
|
||||
- ["org.example.custom.event.typeB", "foo"]
|
||||
# Share only events of type `org.example.custom.event.typeC` whose
|
||||
# state_key is "bar" or "baz"
|
||||
- ["org.example.custom.event.typeC", "bar"]
|
||||
- ["org.example.custom.event.typeC", "baz"]
|
||||
```
|
||||
|
||||
*Changed in Synapse 1.74:* admins can filter the events in prejoin state based
|
||||
on their state key.
|
||||
|
||||
---
|
||||
### `track_puppeted_user_ips`
|
||||
|
||||
|
||||
13
mypy.ini
13
mypy.ini
@@ -12,6 +12,7 @@ local_partial_types = True
|
||||
no_implicit_optional = True
|
||||
disallow_untyped_defs = True
|
||||
strict_equality = True
|
||||
warn_redundant_casts = True
|
||||
|
||||
files =
|
||||
docker/,
|
||||
@@ -88,6 +89,12 @@ disallow_untyped_defs = False
|
||||
[mypy-tests.*]
|
||||
disallow_untyped_defs = False
|
||||
|
||||
[mypy-tests.config.test_api]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.federation.transport.test_client]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.handlers.test_sso]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
@@ -100,7 +107,7 @@ disallow_untyped_defs = True
|
||||
[mypy-tests.push.test_bulk_push_rule_evaluator]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.test_server]
|
||||
[mypy-tests.rest.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.state.test_profile]
|
||||
@@ -109,10 +116,10 @@ disallow_untyped_defs = True
|
||||
[mypy-tests.storage.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.rest.*]
|
||||
[mypy-tests.test_server]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.federation.transport.test_client]
|
||||
[mypy-tests.types.*]
|
||||
disallow_untyped_defs = True
|
||||
|
||||
[mypy-tests.util.caches.*]
|
||||
|
||||
@@ -21,14 +21,25 @@ name = "synapse.synapse_rust"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.63"
|
||||
env_logger = "0.10.0"
|
||||
futures = "0.3.25"
|
||||
futures-util = "0.3.25"
|
||||
http = "0.2.8"
|
||||
hyper = { version = "0.14.23", features = ["client", "http1", "http2", "runtime", "server", "full"] }
|
||||
hyper-tls = "0.5.0"
|
||||
lazy_static = "1.4.0"
|
||||
log = "0.4.17"
|
||||
native-tls = "0.2.11"
|
||||
pyo3 = { version = "0.17.1", features = ["extension-module", "macros", "anyhow", "abi3", "abi3-py37"] }
|
||||
pyo3-asyncio = { version = "0.17.0", features = ["tokio", "tokio-runtime"] }
|
||||
pyo3-log = "0.7.0"
|
||||
pythonize = "0.17.0"
|
||||
regex = "1.6.0"
|
||||
serde = { version = "1.0.144", features = ["derive"] }
|
||||
serde_json = "1.0.85"
|
||||
tokio = "1.23.0"
|
||||
tokio-native-tls = "0.3.0"
|
||||
trust-dns-resolver = "0.22.0"
|
||||
|
||||
[build-dependencies]
|
||||
blake2 = "0.10.4"
|
||||
|
||||
158
rust/src/http/mod.rs
Normal file
158
rust/src/http/mod.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::Error;
|
||||
use http::{Request, Uri};
|
||||
use hyper::Body;
|
||||
use log::info;
|
||||
use pyo3::{
|
||||
pyclass, pymethods,
|
||||
types::{PyBytes, PyModule},
|
||||
IntoPy, PyAny, PyObject, PyResult, Python, ToPyObject,
|
||||
};
|
||||
|
||||
use self::resolver::{MatrixConnector, MatrixResolver};
|
||||
|
||||
pub mod resolver;
|
||||
|
||||
/// Called when registering modules with python.
|
||||
pub fn register_module(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
let child_module = PyModule::new(py, "http")?;
|
||||
child_module.add_class::<HttpClient>()?;
|
||||
child_module.add_class::<MatrixResponse>()?;
|
||||
|
||||
m.add_submodule(child_module)?;
|
||||
|
||||
// We need to manually add the module to sys.modules to make `from
|
||||
// synapse.synapse_rust import push` work.
|
||||
py.import("sys")?
|
||||
.getattr("modules")?
|
||||
.set_item("synapse.synapse_rust.http", child_module)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Bytes(pub Vec<u8>);
|
||||
|
||||
impl ToPyObject for Bytes {
|
||||
fn to_object(&self, py: Python<'_>) -> pyo3::PyObject {
|
||||
PyBytes::new(py, &self.0).into_py(py)
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoPy<PyObject> for Bytes {
|
||||
fn into_py(self, py: Python<'_>) -> PyObject {
|
||||
self.to_object(py)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[pyclass]
|
||||
pub struct MatrixResponse {
|
||||
#[pyo3(get)]
|
||||
pub code: u16,
|
||||
#[pyo3(get)]
|
||||
pub phrase: &'static str,
|
||||
#[pyo3(get)]
|
||||
pub content: Bytes,
|
||||
#[pyo3(get)]
|
||||
pub headers: HashMap<String, Bytes>,
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Clone)]
|
||||
pub struct HttpClient {
|
||||
client: hyper::Client<MatrixConnector>,
|
||||
resolver: MatrixResolver,
|
||||
}
|
||||
|
||||
impl HttpClient {
|
||||
pub fn new() -> Result<Self, Error> {
|
||||
let resolver = MatrixResolver::new()?;
|
||||
|
||||
let client =
|
||||
hyper::Client::builder().build(MatrixConnector::with_resolver(resolver.clone()));
|
||||
|
||||
Ok(HttpClient { client, resolver })
|
||||
}
|
||||
|
||||
pub async fn async_request(
|
||||
&self,
|
||||
url: String,
|
||||
method: String,
|
||||
headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
|
||||
body: Option<Vec<u8>>,
|
||||
) -> Result<MatrixResponse, Error> {
|
||||
let uri: Uri = url.try_into()?;
|
||||
|
||||
let mut builder = Request::builder().method(&*method).uri(uri.clone());
|
||||
|
||||
for (key, values) in headers {
|
||||
for value in values {
|
||||
builder = builder.header(key.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
if uri.scheme_str() == Some("matrix") {
|
||||
let endpoints = self.resolver.resolve_server_name_from_uri(&uri).await?;
|
||||
if let Some(endpoint) = endpoints.first() {
|
||||
builder = builder.header("Host", &endpoint.host_header);
|
||||
}
|
||||
}
|
||||
|
||||
let request = if let Some(body) = body {
|
||||
builder.body(Body::from(body))?
|
||||
} else {
|
||||
builder.body(Body::empty())?
|
||||
};
|
||||
|
||||
let response = self.client.request(request).await?;
|
||||
|
||||
let code = response.status().as_u16();
|
||||
let phrase = response.status().canonical_reason().unwrap_or_default();
|
||||
|
||||
let headers = response
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), Bytes(v.as_bytes().to_owned())))
|
||||
.collect();
|
||||
|
||||
let body = response.into_body();
|
||||
|
||||
let bytes = hyper::body::to_bytes(body).await?;
|
||||
let content = Bytes(bytes.to_vec());
|
||||
|
||||
Ok(MatrixResponse {
|
||||
code,
|
||||
phrase,
|
||||
content,
|
||||
headers,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl HttpClient {
|
||||
#[new]
|
||||
fn py_new() -> Result<Self, Error> {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
fn request<'a>(
|
||||
&'a self,
|
||||
py: Python<'a>,
|
||||
url: String,
|
||||
method: String,
|
||||
headers: HashMap<Vec<u8>, Vec<Vec<u8>>>,
|
||||
body: Option<Vec<u8>>,
|
||||
) -> PyResult<&'a PyAny> {
|
||||
pyo3::prepare_freethreaded_python();
|
||||
|
||||
let client = self.clone();
|
||||
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
let resp = client.async_request(url, method, headers, body).await?;
|
||||
Ok(resp)
|
||||
})
|
||||
}
|
||||
}
|
||||
432
rust/src/http/resolver.rs
Normal file
432
rust/src/http/resolver.rs
Normal file
@@ -0,0 +1,432 @@
|
||||
use std::collections::BTreeMap;
|
||||
use std::future::Future;
|
||||
use std::net::IpAddr;
|
||||
use std::pin::Pin;
|
||||
use std::str::FromStr;
|
||||
use std::{
|
||||
io::Cursor,
|
||||
sync::{Arc, Mutex},
|
||||
task::{self, Poll},
|
||||
};
|
||||
|
||||
use anyhow::{bail, Error};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use futures_util::stream::StreamExt;
|
||||
use http::Uri;
|
||||
use hyper::client::connect::Connection;
|
||||
use hyper::client::connect::{Connected, HttpConnector};
|
||||
use hyper::server::conn::Http;
|
||||
use hyper::service::Service;
|
||||
use hyper::Client;
|
||||
use hyper_tls::HttpsConnector;
|
||||
use hyper_tls::MaybeHttpsStream;
|
||||
use log::{debug, info};
|
||||
use native_tls::TlsConnector;
|
||||
use serde::Deserialize;
|
||||
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_native_tls::TlsConnector as AsyncTlsConnector;
|
||||
use trust_dns_resolver::error::ResolveErrorKind;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Endpoint {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
|
||||
pub host_header: String,
|
||||
pub tls_name: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MatrixResolver {
|
||||
resolver: trust_dns_resolver::TokioAsyncResolver,
|
||||
http_client: Client<HttpsConnector<HttpConnector>>,
|
||||
}
|
||||
|
||||
impl MatrixResolver {
|
||||
pub fn new() -> Result<MatrixResolver, Error> {
|
||||
let http_client = hyper::Client::builder().build(HttpsConnector::new());
|
||||
|
||||
MatrixResolver::with_client(http_client)
|
||||
}
|
||||
|
||||
pub fn with_client(
|
||||
http_client: Client<HttpsConnector<HttpConnector>>,
|
||||
) -> Result<MatrixResolver, Error> {
|
||||
let resolver = trust_dns_resolver::TokioAsyncResolver::tokio_from_system_conf()?;
|
||||
|
||||
Ok(MatrixResolver {
|
||||
resolver,
|
||||
http_client,
|
||||
})
|
||||
}
|
||||
|
||||
/// Does SRV lookup
|
||||
pub async fn resolve_server_name_from_uri(&self, uri: &Uri) -> Result<Vec<Endpoint>, Error> {
|
||||
let host = uri.host().expect("URI has no host").to_string();
|
||||
let port = uri.port_u16();
|
||||
|
||||
self.resolve_server_name_from_host_port(host, port).await
|
||||
}
|
||||
|
||||
pub async fn resolve_server_name_from_host_port(
|
||||
&self,
|
||||
mut host: String,
|
||||
mut port: Option<u16>,
|
||||
) -> Result<Vec<Endpoint>, Error> {
|
||||
let mut authority = if let Some(p) = port {
|
||||
format!("{}:{}", host, p)
|
||||
} else {
|
||||
host.to_string()
|
||||
};
|
||||
|
||||
// If a literal IP or includes port then we shortcircuit.
|
||||
if host.parse::<IpAddr>().is_ok() || port.is_some() {
|
||||
return Ok(vec![Endpoint {
|
||||
host: host.to_string(),
|
||||
port: port.unwrap_or(8448),
|
||||
|
||||
host_header: authority.to_string(),
|
||||
tls_name: host.to_string(),
|
||||
}]);
|
||||
}
|
||||
|
||||
// Do well-known delegation lookup.
|
||||
if let Some(server) = get_well_known(&self.http_client, &host).await {
|
||||
let a = http::uri::Authority::from_str(&server.server)?;
|
||||
host = a.host().to_string();
|
||||
port = a.port_u16();
|
||||
authority = a.to_string();
|
||||
}
|
||||
|
||||
// If a literal IP or includes port then we shortcircuit.
|
||||
if host.parse::<IpAddr>().is_ok() || port.is_some() {
|
||||
return Ok(vec![Endpoint {
|
||||
host: host.clone(),
|
||||
port: port.unwrap_or(8448),
|
||||
|
||||
host_header: authority.to_string(),
|
||||
tls_name: host.clone(),
|
||||
}]);
|
||||
}
|
||||
|
||||
let result = self
|
||||
.resolver
|
||||
.srv_lookup(format!("_matrix._tcp.{}", host))
|
||||
.await;
|
||||
|
||||
let records = match result {
|
||||
Ok(records) => records,
|
||||
Err(err) => match err.kind() {
|
||||
ResolveErrorKind::NoRecordsFound { .. } => {
|
||||
return Ok(vec![Endpoint {
|
||||
host: host.clone(),
|
||||
port: 8448,
|
||||
host_header: authority.to_string(),
|
||||
tls_name: host.clone(),
|
||||
}])
|
||||
}
|
||||
_ => return Err(err.into()),
|
||||
},
|
||||
};
|
||||
|
||||
let mut priority_map: BTreeMap<u16, Vec<_>> = BTreeMap::new();
|
||||
|
||||
let mut count = 0;
|
||||
for record in records {
|
||||
count += 1;
|
||||
let priority = record.priority();
|
||||
priority_map.entry(priority).or_default().push(record);
|
||||
}
|
||||
|
||||
let mut results = Vec::with_capacity(count);
|
||||
|
||||
for (_priority, records) in priority_map {
|
||||
// TODO: Correctly shuffle records
|
||||
results.extend(records.into_iter().map(|record| Endpoint {
|
||||
host: record.target().to_utf8(),
|
||||
port: record.port(),
|
||||
|
||||
host_header: host.to_string(),
|
||||
tls_name: host.to_string(),
|
||||
}))
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_well_known<C>(http_client: &Client<C>, host: &str) -> Option<WellKnownServer>
|
||||
where
|
||||
C: Service<Uri> + Clone + Sync + Send + 'static,
|
||||
C::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
|
||||
C::Future: Unpin + Send,
|
||||
C::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static,
|
||||
{
|
||||
// TODO: Add timeout.
|
||||
|
||||
let uri = hyper::Uri::builder()
|
||||
.scheme("https")
|
||||
.authority(host)
|
||||
.path_and_query("/.well-known/matrix/server")
|
||||
.build()
|
||||
.ok()?;
|
||||
|
||||
let mut body = http_client.get(uri).await.ok()?.into_body();
|
||||
|
||||
let mut vec = Vec::new();
|
||||
while let Some(next) = body.next().await {
|
||||
let chunk = next.ok()?;
|
||||
vec.extend(chunk);
|
||||
}
|
||||
|
||||
serde_json::from_slice(&vec).ok()?
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct WellKnownServer {
|
||||
#[serde(rename = "m.server")]
|
||||
server: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MatrixConnector {
|
||||
resolver: MatrixResolver,
|
||||
}
|
||||
|
||||
impl MatrixConnector {
|
||||
pub fn with_resolver(resolver: MatrixResolver) -> MatrixConnector {
|
||||
MatrixConnector { resolver }
|
||||
}
|
||||
}
|
||||
|
||||
impl Service<Uri> for MatrixConnector {
|
||||
type Response = MaybeHttpsStream<TcpStream>;
|
||||
type Error = Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// This connector is always ready, but others might not be.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, dst: Uri) -> Self::Future {
|
||||
let resolver = self.resolver.clone();
|
||||
|
||||
if dst.scheme_str() != Some("matrix") {
|
||||
debug!("Got non-matrix scheme");
|
||||
return HttpsConnector::new()
|
||||
.call(dst)
|
||||
.map_err(|e| Error::msg(e))
|
||||
.boxed();
|
||||
}
|
||||
|
||||
async move {
|
||||
let endpoints = resolver
|
||||
.resolve_server_name_from_host_port(
|
||||
dst.host().expect("hostname").to_string(),
|
||||
dst.port_u16(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
debug!("Got endpoints: {:?}", endpoints);
|
||||
|
||||
for endpoint in endpoints {
|
||||
match try_connecting(&dst, &endpoint).await {
|
||||
Ok(r) => return Ok(r),
|
||||
// Errors here are not unexpected, and we just move on
|
||||
// with our lives.
|
||||
Err(e) => info!(
|
||||
"Failed to connect to {} via {}:{} because {}",
|
||||
dst.host().expect("hostname"),
|
||||
endpoint.host,
|
||||
endpoint.port,
|
||||
e,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
bail!(
|
||||
"failed to resolve host: {:?} port {:?}",
|
||||
dst.host(),
|
||||
dst.port()
|
||||
)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempts to connect to a particular endpoint.
|
||||
async fn try_connecting(
|
||||
dst: &Uri,
|
||||
endpoint: &Endpoint,
|
||||
) -> Result<MaybeHttpsStream<TcpStream>, Error> {
|
||||
let tcp = TcpStream::connect((&endpoint.host as &str, endpoint.port)).await?;
|
||||
|
||||
let connector: AsyncTlsConnector = if dst.host().expect("hostname").contains("localhost") {
|
||||
TlsConnector::builder()
|
||||
.danger_accept_invalid_certs(true)
|
||||
.build()?
|
||||
.into()
|
||||
} else {
|
||||
TlsConnector::new().unwrap().into()
|
||||
};
|
||||
|
||||
let tls = connector.connect(&endpoint.tls_name, tcp).await?;
|
||||
|
||||
Ok(tls.into())
|
||||
}
|
||||
|
||||
/// A connector that reutrns a connection which returns 200 OK to all connections.
|
||||
#[derive(Clone)]
|
||||
pub struct TestConnector;
|
||||
|
||||
impl Service<Uri> for TestConnector {
|
||||
type Response = TestConnection;
|
||||
type Error = Error;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
|
||||
|
||||
fn poll_ready(&mut self, _: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
// This connector is always ready, but others might not be.
|
||||
Poll::Ready(Ok(()))
|
||||
}
|
||||
|
||||
fn call(&mut self, _dst: Uri) -> Self::Future {
|
||||
let (client, server) = TestConnection::double_ended();
|
||||
|
||||
{
|
||||
let service = hyper::service::service_fn(|_| async move {
|
||||
Ok(hyper::Response::new(hyper::Body::from("Hello World")))
|
||||
as Result<_, hyper::http::Error>
|
||||
});
|
||||
let fut = Http::new().serve_connection(server, service);
|
||||
tokio::spawn(fut);
|
||||
}
|
||||
|
||||
futures::future::ok(client).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct TestConnectionInner {
|
||||
outbound_buffer: Cursor<Vec<u8>>,
|
||||
inbound_buffer: Cursor<Vec<u8>>,
|
||||
wakers: Vec<futures::task::Waker>,
|
||||
}
|
||||
|
||||
/// A in memory connection for use with tests.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct TestConnection {
|
||||
inner: Arc<Mutex<TestConnectionInner>>,
|
||||
direction: bool,
|
||||
}
|
||||
|
||||
impl TestConnection {
|
||||
pub fn double_ended() -> (TestConnection, TestConnection) {
|
||||
let inner: Arc<Mutex<TestConnectionInner>> = Arc::default();
|
||||
|
||||
let a = TestConnection {
|
||||
inner: inner.clone(),
|
||||
direction: false,
|
||||
};
|
||||
|
||||
let b = TestConnection {
|
||||
inner,
|
||||
direction: true,
|
||||
};
|
||||
|
||||
(a, b)
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for TestConnection {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
buf: &mut ReadBuf<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let mut conn = self.inner.lock().expect("mutex");
|
||||
|
||||
let buffer = if self.direction {
|
||||
&mut conn.inbound_buffer
|
||||
} else {
|
||||
&mut conn.outbound_buffer
|
||||
};
|
||||
|
||||
let bytes_read = std::io::Read::read(buffer, buf.initialize_unfilled())?;
|
||||
buf.advance(bytes_read);
|
||||
if bytes_read > 0 {
|
||||
Poll::Ready(Ok(()))
|
||||
} else {
|
||||
conn.wakers.push(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for TestConnection {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
_cx: &mut task::Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
let mut conn = self.inner.lock().expect("mutex");
|
||||
|
||||
if self.direction {
|
||||
conn.outbound_buffer.get_mut().extend_from_slice(buf);
|
||||
} else {
|
||||
conn.inbound_buffer.get_mut().extend_from_slice(buf);
|
||||
}
|
||||
|
||||
for waker in conn.wakers.drain(..) {
|
||||
waker.wake()
|
||||
}
|
||||
|
||||
Poll::Ready(Ok(buf.len()))
|
||||
}
|
||||
fn poll_flush(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let mut conn = self.inner.lock().expect("mutex");
|
||||
|
||||
if self.direction {
|
||||
Pin::new(&mut conn.outbound_buffer).poll_flush(cx)
|
||||
} else {
|
||||
Pin::new(&mut conn.inbound_buffer).poll_flush(cx)
|
||||
}
|
||||
}
|
||||
fn poll_shutdown(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut task::Context<'_>,
|
||||
) -> Poll<Result<(), std::io::Error>> {
|
||||
let mut conn = self.inner.lock().expect("mutex");
|
||||
|
||||
if self.direction {
|
||||
Pin::new(&mut conn.outbound_buffer).poll_shutdown(cx)
|
||||
} else {
|
||||
Pin::new(&mut conn.inbound_buffer).poll_shutdown(cx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Connection for TestConnection {
|
||||
fn connected(&self) -> Connected {
|
||||
Connected::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_memory_connection() {
|
||||
let client: hyper::Client<_, hyper::Body> = hyper::Client::builder().build(TestConnector);
|
||||
|
||||
let response = client
|
||||
.get("http://localhost".parse().unwrap())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(response.status().is_success());
|
||||
|
||||
let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();
|
||||
assert_eq!(&bytes[..], b"Hello World");
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
use pyo3::prelude::*;
|
||||
|
||||
pub mod http;
|
||||
pub mod push;
|
||||
|
||||
/// Returns the hash of all the rust source files at the time it was compiled.
|
||||
@@ -26,6 +27,7 @@ fn synapse_rust(py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_function(wrap_pyfunction!(get_rust_file_digest, m)?)?;
|
||||
|
||||
push::register_module(py, m)?;
|
||||
http::register_module(py, m)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ import time
|
||||
import urllib.request
|
||||
from os import path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Any, List, Optional, cast
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import attr
|
||||
import click
|
||||
@@ -174,9 +174,7 @@ def _prepare() -> None:
|
||||
click.get_current_context().abort()
|
||||
|
||||
# Switch to the release branch.
|
||||
# Cast safety: parse() won't return a version.LegacyVersion from our
|
||||
# version string format.
|
||||
parsed_new_version = cast(version.Version, version.parse(new_version))
|
||||
parsed_new_version = version.parse(new_version)
|
||||
|
||||
# We assume for debian changelogs that we only do RCs or full releases.
|
||||
assert not parsed_new_version.is_devrelease
|
||||
|
||||
16
stubs/synapse/synapse_rust/http.pyi
Normal file
16
stubs/synapse/synapse_rust/http.pyi
Normal file
@@ -0,0 +1,16 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
class MatrixResponse:
|
||||
code: int
|
||||
phrase: str
|
||||
content: bytes
|
||||
headers: Dict[str, str]
|
||||
|
||||
class HttpClient:
|
||||
async def request(
|
||||
self,
|
||||
url: str,
|
||||
method: str,
|
||||
headers: Dict[bytes, List[bytes]],
|
||||
body: Optional[bytes],
|
||||
) -> MatrixResponse: ...
|
||||
@@ -29,7 +29,7 @@ if sys.version_info < (3, 7):
|
||||
sys.exit(1)
|
||||
|
||||
# Allow using the asyncio reactor via env var.
|
||||
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")):
|
||||
if strtobool(os.environ.get("SYNAPSE_ASYNC_IO_REACTOR", "0")) or True:
|
||||
from incremental import Version
|
||||
|
||||
import twisted
|
||||
|
||||
@@ -245,7 +245,9 @@ class ApplicationService:
|
||||
return True
|
||||
|
||||
# likewise with the room's aliases (if it has any)
|
||||
alias_list = await store.get_aliases_for_room(room_id)
|
||||
alias_list = await store.get_aliases_for_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
for alias in alias_list:
|
||||
if self.is_room_alias_in_namespace(alias):
|
||||
return True
|
||||
@@ -311,7 +313,9 @@ class ApplicationService:
|
||||
# Find all the rooms the sender is in
|
||||
if self.is_interested_in_user(user_id.to_string()):
|
||||
return True
|
||||
room_ids = await store.get_rooms_for_user(user_id.to_string())
|
||||
room_ids = await store.get_rooms_for_user(
|
||||
user_id.to_string(), on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
# Then find out if the appservice is interested in any of those rooms
|
||||
for room_id in room_ids:
|
||||
|
||||
@@ -33,6 +33,9 @@ def validate_config(
|
||||
config: the configuration value to be validated
|
||||
config_path: the path within the config file. This will be used as a basis
|
||||
for the error message.
|
||||
|
||||
Raises:
|
||||
ConfigError, if validation fails.
|
||||
"""
|
||||
try:
|
||||
jsonschema.validate(config, json_schema)
|
||||
|
||||
@@ -13,12 +13,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Iterable
|
||||
from typing import Any, Iterable, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.config._base import Config, ConfigError
|
||||
from synapse.config._util import validate_config
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types.state import StateFilter
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,16 +27,20 @@ logger = logging.getLogger(__name__)
|
||||
class ApiConfig(Config):
|
||||
section = "api"
|
||||
|
||||
room_prejoin_state: StateFilter
|
||||
track_puppetted_users_ips: bool
|
||||
|
||||
def read_config(self, config: JsonDict, **kwargs: Any) -> None:
|
||||
validate_config(_MAIN_SCHEMA, config, ())
|
||||
self.room_prejoin_state = list(self._get_prejoin_state_types(config))
|
||||
self.room_prejoin_state = StateFilter.from_types(
|
||||
self._get_prejoin_state_entries(config)
|
||||
)
|
||||
self.track_puppeted_user_ips = config.get("track_puppeted_user_ips", False)
|
||||
|
||||
def _get_prejoin_state_types(self, config: JsonDict) -> Iterable[str]:
|
||||
"""Get the event types to include in the prejoin state
|
||||
|
||||
Parses the config and returns an iterable of the event types to be included.
|
||||
"""
|
||||
def _get_prejoin_state_entries(
|
||||
self, config: JsonDict
|
||||
) -> Iterable[Tuple[str, Optional[str]]]:
|
||||
"""Get the event types and state keys to include in the prejoin state."""
|
||||
room_prejoin_state_config = config.get("room_prejoin_state") or {}
|
||||
|
||||
# backwards-compatibility support for room_invite_state_types
|
||||
@@ -50,33 +55,39 @@ class ApiConfig(Config):
|
||||
|
||||
logger.warning(_ROOM_INVITE_STATE_TYPES_WARNING)
|
||||
|
||||
yield from config["room_invite_state_types"]
|
||||
for event_type in config["room_invite_state_types"]:
|
||||
yield event_type, None
|
||||
return
|
||||
|
||||
if not room_prejoin_state_config.get("disable_default_event_types"):
|
||||
yield from _DEFAULT_PREJOIN_STATE_TYPES
|
||||
yield from _DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS
|
||||
|
||||
yield from room_prejoin_state_config.get("additional_event_types", [])
|
||||
for entry in room_prejoin_state_config.get("additional_event_types", []):
|
||||
if isinstance(entry, str):
|
||||
yield entry, None
|
||||
else:
|
||||
yield entry
|
||||
|
||||
|
||||
_ROOM_INVITE_STATE_TYPES_WARNING = """\
|
||||
WARNING: The 'room_invite_state_types' configuration setting is now deprecated,
|
||||
and replaced with 'room_prejoin_state'. New features may not work correctly
|
||||
unless 'room_invite_state_types' is removed. See the sample configuration file for
|
||||
details of 'room_prejoin_state'.
|
||||
unless 'room_invite_state_types' is removed. See the config documentation at
|
||||
https://matrix-org.github.io/synapse/latest/usage/configuration/config_documentation.html#room_prejoin_state
|
||||
for details of 'room_prejoin_state'.
|
||||
--------------------------------------------------------------------------------
|
||||
"""
|
||||
|
||||
_DEFAULT_PREJOIN_STATE_TYPES = [
|
||||
EventTypes.JoinRules,
|
||||
EventTypes.CanonicalAlias,
|
||||
EventTypes.RoomAvatar,
|
||||
EventTypes.RoomEncryption,
|
||||
EventTypes.Name,
|
||||
_DEFAULT_PREJOIN_STATE_TYPES_AND_STATE_KEYS = [
|
||||
(EventTypes.JoinRules, ""),
|
||||
(EventTypes.CanonicalAlias, ""),
|
||||
(EventTypes.RoomAvatar, ""),
|
||||
(EventTypes.RoomEncryption, ""),
|
||||
(EventTypes.Name, ""),
|
||||
# Per MSC1772.
|
||||
EventTypes.Create,
|
||||
(EventTypes.Create, ""),
|
||||
# Per MSC3173.
|
||||
EventTypes.Topic,
|
||||
(EventTypes.Topic, ""),
|
||||
]
|
||||
|
||||
|
||||
@@ -90,7 +101,17 @@ _ROOM_PREJOIN_STATE_CONFIG_SCHEMA = {
|
||||
"disable_default_event_types": {"type": "boolean"},
|
||||
"additional_event_types": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{"type": "string"},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
"minItems": 2,
|
||||
"maxItems": 2,
|
||||
},
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -28,8 +28,14 @@ from typing import (
|
||||
)
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||
from synapse.api.constants import (
|
||||
MAX_PDU_SIZE,
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
RelationTypes,
|
||||
)
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.room_versions import RoomVersion
|
||||
from synapse.types import JsonDict
|
||||
@@ -674,3 +680,27 @@ def validate_canonicaljson(value: Any) -> None:
|
||||
elif not isinstance(value, (bool, str)) and value is not None:
|
||||
# Other potential JSON values (bool, None, str) are safe.
|
||||
raise SynapseError(400, "Unknown JSON value", Codes.BAD_JSON)
|
||||
|
||||
|
||||
def maybe_upsert_event_field(
|
||||
event: EventBase, container: JsonDict, key: str, value: object
|
||||
) -> bool:
|
||||
"""Upsert an event field, but only if this doesn't make the event too large.
|
||||
|
||||
Returns true iff the upsert took place.
|
||||
"""
|
||||
if key in container:
|
||||
old_value: object = container[key]
|
||||
container[key] = value
|
||||
# NB: here and below, we assume that passing a non-None `time_now` argument to
|
||||
# get_pdu_json doesn't increase the size of the encoded result.
|
||||
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
|
||||
if not upsert_okay:
|
||||
container[key] = old_value
|
||||
else:
|
||||
container[key] = value
|
||||
upsert_okay = len(encode_canonical_json(event.get_pdu_json())) <= MAX_PDU_SIZE
|
||||
if not upsert_okay:
|
||||
del container[key]
|
||||
|
||||
return upsert_okay
|
||||
|
||||
@@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
|
||||
from synapse.events import EventBase, relation_from_event
|
||||
from synapse.events.builder import EventBuilder
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.utils import maybe_upsert_event_field
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.handlers.directory import DirectoryHandler
|
||||
from synapse.logging import opentracing
|
||||
@@ -1739,12 +1740,15 @@ class EventCreationHandler:
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.content["membership"] == Membership.INVITE:
|
||||
event.unsigned[
|
||||
"invite_room_state"
|
||||
] = await self.store.get_stripped_room_state_from_event_context(
|
||||
context,
|
||||
self.room_prejoin_state_types,
|
||||
membership_user_id=event.sender,
|
||||
maybe_upsert_event_field(
|
||||
event,
|
||||
event.unsigned,
|
||||
"invite_room_state",
|
||||
await self.store.get_stripped_room_state_from_event_context(
|
||||
context,
|
||||
self.room_prejoin_state_types,
|
||||
membership_user_id=event.sender,
|
||||
),
|
||||
)
|
||||
|
||||
invitee = UserID.from_string(event.state_key)
|
||||
@@ -1762,11 +1766,14 @@ class EventCreationHandler:
|
||||
event.signatures.update(returned_invite.signatures)
|
||||
|
||||
if event.content["membership"] == Membership.KNOCK:
|
||||
event.unsigned[
|
||||
"knock_room_state"
|
||||
] = await self.store.get_stripped_room_state_from_event_context(
|
||||
context,
|
||||
self.room_prejoin_state_types,
|
||||
maybe_upsert_event_field(
|
||||
event,
|
||||
event.unsigned,
|
||||
"knock_room_state",
|
||||
await self.store.get_stripped_room_state_from_event_context(
|
||||
context,
|
||||
self.room_prejoin_state_types,
|
||||
),
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Redaction:
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import asyncio
|
||||
import cgi
|
||||
import codecs
|
||||
import logging
|
||||
@@ -42,14 +43,18 @@ from canonicaljson import encode_canonical_json
|
||||
from prometheus_client import Counter
|
||||
from signedjson.sign import sign_json
|
||||
from typing_extensions import Literal
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.internet.interfaces import IReactorTime
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet.task import Cooperator
|
||||
from twisted.web.client import ResponseFailed
|
||||
from twisted.internet.testing import StringTransport
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.client import Response, ResponseDone, ResponseFailed
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IBodyProducer, IResponse
|
||||
from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer, IResponse
|
||||
|
||||
import synapse.metrics
|
||||
import synapse.util.retryutils
|
||||
@@ -75,6 +80,7 @@ from synapse.http.types import QueryParams
|
||||
from synapse.logging import opentracing
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||
from synapse.synapse_rust.http import HttpClient
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
|
||||
@@ -199,6 +205,33 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
|
||||
return json_decoder.decode(self._buffer.getvalue())
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
@implementer(IResponse)
|
||||
class RustResponse:
|
||||
version: tuple
|
||||
|
||||
code: int
|
||||
|
||||
phrase: bytes
|
||||
|
||||
headers: Headers
|
||||
|
||||
length: Union[int, UNKNOWN_LENGTH]
|
||||
|
||||
# request: Optional[IClientRequest]
|
||||
|
||||
# previousResponse: Optional[IResponse]
|
||||
|
||||
_data: bytes
|
||||
|
||||
def deliverBody(self, protocol: Protocol):
|
||||
protocol.dataReceived(self._data)
|
||||
protocol.connectionLost(Failure(ResponseDone("Response body fully received")))
|
||||
|
||||
def setPreviousResponse(self, response: IResponse):
|
||||
pass
|
||||
|
||||
|
||||
async def _handle_response(
|
||||
reactor: IReactorTime,
|
||||
timeout_sec: float,
|
||||
@@ -372,6 +405,8 @@ class MatrixFederationHttpClient:
|
||||
|
||||
self._sleeper = AwakenableSleeper(self.reactor)
|
||||
|
||||
self._rust_client = HttpClient()
|
||||
|
||||
def wake_destination(self, destination: str) -> None:
|
||||
"""Called when the remote server may have come back online."""
|
||||
|
||||
@@ -556,11 +591,8 @@ class MatrixFederationHttpClient:
|
||||
destination_bytes, method_bytes, url_to_sign_bytes, json
|
||||
)
|
||||
data = encode_canonical_json(json)
|
||||
producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
|
||||
BytesIO(data), cooperator=self._cooperator
|
||||
)
|
||||
else:
|
||||
producer = None
|
||||
data = None
|
||||
auth_headers = self.build_auth_headers(
|
||||
destination_bytes, method_bytes, url_to_sign_bytes
|
||||
)
|
||||
@@ -591,23 +623,33 @@ class MatrixFederationHttpClient:
|
||||
# * The `Deferred` that joins the forks back together is
|
||||
# wrapped in `make_deferred_yieldable` to restore the
|
||||
# logging context regardless of the path taken.
|
||||
request_deferred = run_in_background(
|
||||
self.agent.request,
|
||||
method_bytes,
|
||||
url_bytes,
|
||||
headers=Headers(headers_dict),
|
||||
bodyProducer=producer,
|
||||
)
|
||||
request_deferred = timeout_deferred(
|
||||
request_deferred,
|
||||
timeout=_sec_timeout,
|
||||
reactor=self.reactor,
|
||||
)
|
||||
# request_deferred = run_in_background(
|
||||
# self._rust_client.request,
|
||||
# url_str,
|
||||
# request.method,
|
||||
# headers_dict,
|
||||
# data,
|
||||
# )
|
||||
# request_deferred = timeout_deferred(
|
||||
# request_deferred,
|
||||
# timeout=_sec_timeout,
|
||||
# reactor=self.reactor,
|
||||
# )
|
||||
|
||||
response = await make_deferred_yieldable(request_deferred)
|
||||
# response = await make_deferred_yieldable(request_deferred)
|
||||
|
||||
response_d = run_in_background(
|
||||
self._rust_client.request,
|
||||
url_str,
|
||||
request.method,
|
||||
headers_dict,
|
||||
data,
|
||||
)
|
||||
response = await make_deferred_yieldable(response_d)
|
||||
except DNSLookupError as e:
|
||||
raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
|
||||
except Exception as e:
|
||||
logger.exception("ERROR")
|
||||
raise RequestSendFailed(e, can_retry=True) from e
|
||||
|
||||
incoming_responses_counter.labels(
|
||||
@@ -615,7 +657,7 @@ class MatrixFederationHttpClient:
|
||||
).inc()
|
||||
|
||||
set_tag(tags.HTTP_STATUS_CODE, response.code)
|
||||
response_phrase = response.phrase.decode("ascii", errors="replace")
|
||||
response_phrase = response.phrase
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
logger.debug(
|
||||
@@ -635,25 +677,7 @@ class MatrixFederationHttpClient:
|
||||
)
|
||||
# :'(
|
||||
# Update transactions table?
|
||||
d = treq.content(response)
|
||||
d = timeout_deferred(
|
||||
d, timeout=_sec_timeout, reactor=self.reactor
|
||||
)
|
||||
|
||||
try:
|
||||
body = await make_deferred_yieldable(d)
|
||||
except Exception as e:
|
||||
# Eh, we're already going to raise an exception so lets
|
||||
# ignore if this fails.
|
||||
logger.warning(
|
||||
"{%s} [%s] Failed to get error response: %s %s: %s",
|
||||
request.txn_id,
|
||||
request.destination,
|
||||
request.method,
|
||||
url_str,
|
||||
_flatten_response_never_received(e),
|
||||
)
|
||||
body = None
|
||||
body = response.content
|
||||
|
||||
exc = HttpResponseException(
|
||||
response.code, response_phrase, body
|
||||
@@ -715,7 +739,19 @@ class MatrixFederationHttpClient:
|
||||
_flatten_response_never_received(e),
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
headers = Headers()
|
||||
for key, value in response.headers.items():
|
||||
headers.addRawHeader(key, value)
|
||||
|
||||
return RustResponse(
|
||||
("HTTP", 1, 1),
|
||||
response.code,
|
||||
response.phrase.encode("ascii"),
|
||||
headers,
|
||||
UNKNOWN_LENGTH,
|
||||
response.content,
|
||||
)
|
||||
|
||||
def build_auth_headers(
|
||||
self,
|
||||
|
||||
@@ -26,6 +26,7 @@ import logging
|
||||
import threading
|
||||
import typing
|
||||
import warnings
|
||||
from asyncio import Future
|
||||
from types import TracebackType
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -814,6 +815,8 @@ def run_in_background( # type: ignore[misc]
|
||||
res = defer.ensureDeferred(res)
|
||||
elif isinstance(res, defer.Deferred):
|
||||
pass
|
||||
elif isinstance(res, Future):
|
||||
res = defer.Deferred.fromFuture(res)
|
||||
elif isinstance(res, Awaitable):
|
||||
# `res` is probably some kind of completed awaitable, such as a `DoneAwaitable`
|
||||
# or `Future` from `make_awaitable`.
|
||||
|
||||
@@ -667,7 +667,8 @@ class DatabasePool:
|
||||
)
|
||||
# also check variables referenced in func's closure
|
||||
if inspect.isfunction(func):
|
||||
f = cast(types.FunctionType, func)
|
||||
# Keep the cast for now---it helps PyCharm to understand what `func` is.
|
||||
f = cast(types.FunctionType, func) # type: ignore[redundant-cast]
|
||||
if f.__closure__:
|
||||
for i, cell in enumerate(f.__closure__):
|
||||
if inspect.isgenerator(cell.cell_contents):
|
||||
|
||||
@@ -16,11 +16,11 @@ import logging
|
||||
import threading
|
||||
import weakref
|
||||
from enum import Enum, auto
|
||||
from itertools import chain
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Container,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
@@ -76,6 +76,7 @@ from synapse.storage.util.id_generators import (
|
||||
)
|
||||
from synapse.storage.util.sequence import build_sequence_generator
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
@@ -879,7 +880,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
async def get_stripped_room_state_from_event_context(
|
||||
self,
|
||||
context: EventContext,
|
||||
state_types_to_include: Container[str],
|
||||
state_keys_to_include: StateFilter,
|
||||
membership_user_id: Optional[str] = None,
|
||||
) -> List[JsonDict]:
|
||||
"""
|
||||
@@ -892,7 +893,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
Args:
|
||||
context: The event context to retrieve state of the room from.
|
||||
state_types_to_include: The type of state events to include.
|
||||
state_keys_to_include: The state events to include, for each event type.
|
||||
membership_user_id: An optional user ID to include the stripped membership state
|
||||
events of. This is useful when generating the stripped state of a room for
|
||||
invites. We want to send membership events of the inviter, so that the
|
||||
@@ -901,21 +902,25 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
Returns:
|
||||
A list of dictionaries, each representing a stripped state event from the room.
|
||||
"""
|
||||
current_state_ids = await context.get_current_state_ids()
|
||||
if membership_user_id:
|
||||
types = chain(
|
||||
state_keys_to_include.to_types(),
|
||||
[(EventTypes.Member, membership_user_id)],
|
||||
)
|
||||
filter = StateFilter.from_types(types)
|
||||
else:
|
||||
filter = state_keys_to_include
|
||||
selected_state_ids = await context.get_current_state_ids(filter)
|
||||
|
||||
# We know this event is not an outlier, so this must be
|
||||
# non-None.
|
||||
assert current_state_ids is not None
|
||||
assert selected_state_ids is not None
|
||||
|
||||
# The state to include
|
||||
state_to_include_ids = [
|
||||
e_id
|
||||
for k, e_id in current_state_ids.items()
|
||||
if k[0] in state_types_to_include
|
||||
or (membership_user_id and k == (EventTypes.Member, membership_user_id))
|
||||
]
|
||||
# Confusingly, get_current_state_events may return events that are discarded by
|
||||
# the filter, if they're in context._state_delta_due_to_event. Strip these away.
|
||||
selected_state_ids = filter.filter_state(selected_state_ids)
|
||||
|
||||
state_to_include = await self.get_events(state_to_include_ids)
|
||||
state_to_include = await self.get_events(selected_state_ids.values())
|
||||
|
||||
return [
|
||||
{
|
||||
|
||||
@@ -77,7 +77,7 @@ class PostgresEngine(
|
||||
# docs: The number is formed by converting the major, minor, and
|
||||
# revision numbers into two-decimal-digit numbers and appending them
|
||||
# together. For example, version 8.1.5 will be returned as 80105
|
||||
self._version = cast(int, db_conn.server_version)
|
||||
self._version = db_conn.server_version
|
||||
allow_unsafe_locale = self.config.get("allow_unsafe_locale", False)
|
||||
|
||||
# Are we on a supported PostgreSQL version?
|
||||
|
||||
@@ -118,6 +118,15 @@ class StateFilter:
|
||||
)
|
||||
)
|
||||
|
||||
def to_types(self) -> Iterable[Tuple[str, Optional[str]]]:
|
||||
"""The inverse to `from_types`."""
|
||||
for (event_type, state_keys) in self.types.items():
|
||||
if state_keys is None:
|
||||
yield event_type, None
|
||||
else:
|
||||
for state_key in state_keys:
|
||||
yield event_type, state_key
|
||||
|
||||
@staticmethod
|
||||
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
|
||||
"""Creates a filter that returns all non-member events, plus the member
|
||||
@@ -343,6 +352,15 @@ class StateFilter:
|
||||
for s in state_keys
|
||||
]
|
||||
|
||||
def wildcard_types(self) -> List[str]:
|
||||
"""Returns a list of event types which require us to fetch all state keys.
|
||||
This will be empty unless `has_wildcards` returns True.
|
||||
|
||||
Returns:
|
||||
A list of event types.
|
||||
"""
|
||||
return [t for t, state_keys in self.types.items() if state_keys is None]
|
||||
|
||||
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
|
||||
"""Return the filter split into two: one which assumes it's exclusively
|
||||
matching against member state, and one which assumes it's matching
|
||||
|
||||
145
tests/config/test_api.py
Normal file
145
tests/config/test_api.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from unittest import TestCase as StdlibTestCase
|
||||
|
||||
import yaml
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.config.api import ApiConfig
|
||||
from synapse.types.state import StateFilter
|
||||
|
||||
DEFAULT_PREJOIN_STATE_PAIRS = {
|
||||
("m.room.join_rules", ""),
|
||||
("m.room.canonical_alias", ""),
|
||||
("m.room.avatar", ""),
|
||||
("m.room.encryption", ""),
|
||||
("m.room.name", ""),
|
||||
("m.room.create", ""),
|
||||
("m.room.topic", ""),
|
||||
}
|
||||
|
||||
|
||||
class TestRoomPrejoinState(StdlibTestCase):
|
||||
def read_config(self, source: str) -> ApiConfig:
|
||||
config = ApiConfig()
|
||||
config.read_config(yaml.safe_load(source))
|
||||
return config
|
||||
|
||||
def test_no_prejoin_state(self) -> None:
|
||||
config = self.read_config("foo: bar")
|
||||
self.assertFalse(config.room_prejoin_state.has_wildcards())
|
||||
self.assertEqual(
|
||||
set(config.room_prejoin_state.concrete_types()), DEFAULT_PREJOIN_STATE_PAIRS
|
||||
)
|
||||
|
||||
def test_disable_default_event_types(self) -> None:
|
||||
config = self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
disable_default_event_types: true
|
||||
"""
|
||||
)
|
||||
self.assertEqual(config.room_prejoin_state, StateFilter.none())
|
||||
|
||||
def test_event_without_state_key(self) -> None:
|
||||
config = self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
disable_default_event_types: true
|
||||
additional_event_types:
|
||||
- foo
|
||||
"""
|
||||
)
|
||||
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
|
||||
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
|
||||
|
||||
def test_event_with_specific_state_key(self) -> None:
|
||||
config = self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
disable_default_event_types: true
|
||||
additional_event_types:
|
||||
- [foo, bar]
|
||||
"""
|
||||
)
|
||||
self.assertFalse(config.room_prejoin_state.has_wildcards())
|
||||
self.assertEqual(
|
||||
set(config.room_prejoin_state.concrete_types()),
|
||||
{("foo", "bar")},
|
||||
)
|
||||
|
||||
def test_repeated_event_with_specific_state_key(self) -> None:
|
||||
config = self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
disable_default_event_types: true
|
||||
additional_event_types:
|
||||
- [foo, bar]
|
||||
- [foo, baz]
|
||||
"""
|
||||
)
|
||||
self.assertFalse(config.room_prejoin_state.has_wildcards())
|
||||
self.assertEqual(
|
||||
set(config.room_prejoin_state.concrete_types()),
|
||||
{("foo", "bar"), ("foo", "baz")},
|
||||
)
|
||||
|
||||
def test_no_specific_state_key_overrides_specific_state_key(self) -> None:
|
||||
config = self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
disable_default_event_types: true
|
||||
additional_event_types:
|
||||
- [foo, bar]
|
||||
- foo
|
||||
"""
|
||||
)
|
||||
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
|
||||
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
|
||||
|
||||
config = self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
disable_default_event_types: true
|
||||
additional_event_types:
|
||||
- foo
|
||||
- [foo, bar]
|
||||
"""
|
||||
)
|
||||
self.assertEqual(config.room_prejoin_state.wildcard_types(), ["foo"])
|
||||
self.assertEqual(config.room_prejoin_state.concrete_types(), [])
|
||||
|
||||
def test_bad_event_type_entry_raises(self) -> None:
|
||||
with self.assertRaises(ConfigError):
|
||||
self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
additional_event_types:
|
||||
- []
|
||||
"""
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
additional_event_types:
|
||||
- [a]
|
||||
"""
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
additional_event_types:
|
||||
- [a, b, c]
|
||||
"""
|
||||
)
|
||||
|
||||
with self.assertRaises(ConfigError):
|
||||
self.read_config(
|
||||
"""
|
||||
room_prejoin_state:
|
||||
additional_event_types:
|
||||
- [true, 1.23]
|
||||
"""
|
||||
)
|
||||
@@ -12,19 +12,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest as stdlib_unittest
|
||||
|
||||
from synapse.api.constants import EventContentFields
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events.utils import (
|
||||
SerializeEventConfig,
|
||||
copy_and_fixup_power_levels_contents,
|
||||
maybe_upsert_event_field,
|
||||
prune_event,
|
||||
serialize_event,
|
||||
)
|
||||
from synapse.util.frozenutils import freeze
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
def MockEvent(**kwargs):
|
||||
if "event_id" not in kwargs:
|
||||
@@ -34,7 +35,31 @@ def MockEvent(**kwargs):
|
||||
return make_event_from_dict(kwargs)
|
||||
|
||||
|
||||
class PruneEventTestCase(unittest.TestCase):
|
||||
class TestMaybeUpsertEventField(stdlib_unittest.TestCase):
|
||||
def test_update_okay(self) -> None:
|
||||
event = make_event_from_dict({"event_id": "$1234"})
|
||||
success = maybe_upsert_event_field(event, event.unsigned, "key", "value")
|
||||
self.assertTrue(success)
|
||||
self.assertEqual(event.unsigned["key"], "value")
|
||||
|
||||
def test_update_not_okay(self) -> None:
|
||||
event = make_event_from_dict({"event_id": "$1234"})
|
||||
LARGE_STRING = "a" * 100_000
|
||||
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
|
||||
self.assertFalse(success)
|
||||
self.assertNotIn("key", event.unsigned)
|
||||
|
||||
def test_update_not_okay_leaves_original_value(self) -> None:
|
||||
event = make_event_from_dict(
|
||||
{"event_id": "$1234", "unsigned": {"key": "value"}}
|
||||
)
|
||||
LARGE_STRING = "a" * 100_000
|
||||
success = maybe_upsert_event_field(event, event.unsigned, "key", LARGE_STRING)
|
||||
self.assertFalse(success)
|
||||
self.assertEqual(event.unsigned["key"], "value")
|
||||
|
||||
|
||||
class PruneEventTestCase(stdlib_unittest.TestCase):
|
||||
def run_test(self, evdict, matchdict, **kwargs):
|
||||
"""
|
||||
Asserts that a new event constructed with `evdict` will look like
|
||||
@@ -391,7 +416,7 @@ class PruneEventTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class SerializeEventTestCase(unittest.TestCase):
|
||||
class SerializeEventTestCase(stdlib_unittest.TestCase):
|
||||
def serialize(self, ev, fields):
|
||||
return serialize_event(
|
||||
ev, 1479807801915, config=SerializeEventConfig(only_event_fields=fields)
|
||||
@@ -513,7 +538,7 @@ class SerializeEventTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class CopyPowerLevelsContentTestCase(unittest.TestCase):
|
||||
class CopyPowerLevelsContentTestCase(stdlib_unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.test_content = {
|
||||
"ban": 50,
|
||||
|
||||
@@ -26,7 +26,7 @@ from synapse.types import JsonDict, RoomID, StateMap, UserID
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests.unittest import HomeserverTestCase, TestCase
|
||||
from tests.unittest import HomeserverTestCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -494,624 +494,3 @@ class StateStoreTestCase(HomeserverTestCase):
|
||||
|
||||
self.assertEqual(is_all, True)
|
||||
self.assertDictEqual({(e5.type, e5.state_key): e5.event_id}, state_dict)
|
||||
|
||||
|
||||
class StateFilterDifferenceTestCase(TestCase):
|
||||
def assert_difference(
|
||||
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
|
||||
) -> None:
|
||||
self.assertEqual(
|
||||
minuend.approx_difference(subtrahend),
|
||||
expected,
|
||||
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
|
||||
)
|
||||
|
||||
def test_state_filter_difference_no_include_other_minus_no_include_other(
|
||||
self,
|
||||
) -> None:
|
||||
"""
|
||||
Tests the StateFilter.approx_difference method
|
||||
where, in a.approx_difference(b), both a and b do not have the
|
||||
include_others flag set.
|
||||
"""
|
||||
# (wildcard on state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.Create: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: {"@wombat:spqr"}},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.CanonicalAlias: {""}},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (specific state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
|
||||
"""
|
||||
Tests the StateFilter.approx_difference method
|
||||
where, in a.approx_difference(b), only a has the include_others flag set.
|
||||
"""
|
||||
# (wildcard on state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.Create: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Create: None,
|
||||
EventTypes.Member: set(),
|
||||
EventTypes.CanonicalAlias: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
# This also shows that the resultant state filter is normalised.
|
||||
self.assert_difference(
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
EventTypes.Create: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter(types=frozendict(), include_others=True),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter(
|
||||
types=frozendict(),
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (specific state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
|
||||
"""
|
||||
Tests the StateFilter.approx_difference method
|
||||
where, in a.approx_difference(b), both a and b have the include_others
|
||||
flag set.
|
||||
"""
|
||||
# (wildcard on state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.Create: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter(types=frozendict(), include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter(
|
||||
types=frozendict(),
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
EventTypes.Create: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
EventTypes.Create: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@spqr:spqr"},
|
||||
EventTypes.Create: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
|
||||
"""
|
||||
Tests the StateFilter.approx_difference method
|
||||
where, in a.approx_difference(b), only b has the include_others flag set.
|
||||
"""
|
||||
# (wildcard on state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.Create: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter(types=frozendict(), include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: {"@wombat:spqr"}},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter(
|
||||
types=frozendict(),
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@spqr:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_filter_difference_simple_cases(self) -> None:
|
||||
"""
|
||||
Tests some very simple cases of the StateFilter approx_difference,
|
||||
that are not explicitly tested by the more in-depth tests.
|
||||
"""
|
||||
|
||||
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
|
||||
|
||||
self.assert_difference(
|
||||
StateFilter.all(),
|
||||
StateFilter.none(),
|
||||
StateFilter.all(),
|
||||
)
|
||||
|
||||
|
||||
class StateFilterTestCase(TestCase):
|
||||
def test_return_expanded(self) -> None:
|
||||
"""
|
||||
Tests the behaviour of the return_expanded() function that expands
|
||||
StateFilters to include more state types (for the sake of cache hit rate).
|
||||
"""
|
||||
|
||||
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
|
||||
|
||||
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
|
||||
|
||||
# Concrete-only state filters stay the same
|
||||
# (Case: mixed filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
"some.other.state.type": {""},
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
"some.other.state.type": {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Concrete-only state filters stay the same
|
||||
# (Case: non-member-only filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{"some.other.state.type": {""}}, include_others=False
|
||||
).return_expanded(),
|
||||
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
|
||||
)
|
||||
|
||||
# Concrete-only state filters stay the same
|
||||
# (Case: member-only filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Wildcard member-only state filters stay the same
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# If there is a wildcard in the non-member portion of the filter,
|
||||
# it's expanded to include ALL non-member events.
|
||||
# (Case: mixed filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
"some.other.state.type": None,
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# If there is a wildcard in the non-member portion of the filter,
|
||||
# it's expanded to include ALL non-member events.
|
||||
# (Case: non-member-only filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
"some.other.state.type": None,
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
|
||||
)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
"some.other.state.type": None,
|
||||
"yet.another.state.type": {"wombat"},
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
|
||||
)
|
||||
|
||||
0
tests/types/__init__.py
Normal file
0
tests/types/__init__.py
Normal file
627
tests/types/test_state.py
Normal file
627
tests/types/test_state.py
Normal file
@@ -0,0 +1,627 @@
|
||||
from frozendict import frozendict
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.types.state import StateFilter
|
||||
|
||||
from tests.unittest import TestCase
|
||||
|
||||
|
||||
class StateFilterDifferenceTestCase(TestCase):
|
||||
def assert_difference(
|
||||
self, minuend: StateFilter, subtrahend: StateFilter, expected: StateFilter
|
||||
) -> None:
|
||||
self.assertEqual(
|
||||
minuend.approx_difference(subtrahend),
|
||||
expected,
|
||||
f"StateFilter difference not correct:\n\n\t{minuend!r}\nminus\n\t{subtrahend!r}\nwas\n\t{minuend.approx_difference(subtrahend)}\nexpected\n\t{expected}",
|
||||
)
|
||||
|
||||
def test_state_filter_difference_no_include_other_minus_no_include_other(
|
||||
self,
|
||||
) -> None:
|
||||
"""
|
||||
Tests the StateFilter.approx_difference method
|
||||
where, in a.approx_difference(b), both a and b do not have the
|
||||
include_others flag set.
|
||||
"""
|
||||
# (wildcard on state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.Create: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze({EventTypes.Create: None}, include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: {"@wombat:spqr"}},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.CanonicalAlias: {""}},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (specific state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_filter_difference_include_other_minus_no_include_other(self) -> None:
|
||||
"""
|
||||
Tests the StateFilter.approx_difference method
|
||||
where, in a.approx_difference(b), only a has the include_others flag set.
|
||||
"""
|
||||
# (wildcard on state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.Create: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Create: None,
|
||||
EventTypes.Member: set(),
|
||||
EventTypes.CanonicalAlias: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
# This also shows that the resultant state filter is normalised.
|
||||
self.assert_difference(
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
EventTypes.Create: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter(types=frozendict(), include_others=True),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter(
|
||||
types=frozendict(),
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (specific state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_filter_difference_include_other_minus_include_other(self) -> None:
|
||||
"""
|
||||
Tests the StateFilter.approx_difference method
|
||||
where, in a.approx_difference(b), both a and b have the include_others
|
||||
flag set.
|
||||
"""
|
||||
# (wildcard on state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.Create: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter(types=frozendict(), include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=True),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter(
|
||||
types=frozendict(),
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
EventTypes.Create: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
EventTypes.Create: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@spqr:spqr"},
|
||||
EventTypes.Create: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_filter_difference_no_include_other_minus_include_other(self) -> None:
|
||||
"""
|
||||
Tests the StateFilter.approx_difference method
|
||||
where, in a.approx_difference(b), only b has the include_others flag set.
|
||||
"""
|
||||
# (wildcard on state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.Create: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None, EventTypes.CanonicalAlias: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter(types=frozendict(), include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: {"@wombat:spqr"}},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze({EventTypes.Member: None}, include_others=False),
|
||||
)
|
||||
|
||||
# (wildcard on state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (wildcard on state keys):
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter(
|
||||
types=frozendict(),
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (specific state keys)
|
||||
# This one is an over-approximation because we can't represent
|
||||
# 'all state keys except a few named examples'
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr"},
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@spqr:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# (specific state keys) - (no state keys)
|
||||
self.assert_difference(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
EventTypes.CanonicalAlias: {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: set(),
|
||||
},
|
||||
include_others=True,
|
||||
),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:spqr", "@spqr:spqr"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
def test_state_filter_difference_simple_cases(self) -> None:
|
||||
"""
|
||||
Tests some very simple cases of the StateFilter approx_difference,
|
||||
that are not explicitly tested by the more in-depth tests.
|
||||
"""
|
||||
|
||||
self.assert_difference(StateFilter.all(), StateFilter.all(), StateFilter.none())
|
||||
|
||||
self.assert_difference(
|
||||
StateFilter.all(),
|
||||
StateFilter.none(),
|
||||
StateFilter.all(),
|
||||
)
|
||||
|
||||
|
||||
class StateFilterTestCase(TestCase):
|
||||
def test_return_expanded(self) -> None:
|
||||
"""
|
||||
Tests the behaviour of the return_expanded() function that expands
|
||||
StateFilters to include more state types (for the sake of cache hit rate).
|
||||
"""
|
||||
|
||||
self.assertEqual(StateFilter.all().return_expanded(), StateFilter.all())
|
||||
|
||||
self.assertEqual(StateFilter.none().return_expanded(), StateFilter.none())
|
||||
|
||||
# Concrete-only state filters stay the same
|
||||
# (Case: mixed filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
"some.other.state.type": {""},
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
"some.other.state.type": {""},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Concrete-only state filters stay the same
|
||||
# (Case: non-member-only filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{"some.other.state.type": {""}}, include_others=False
|
||||
).return_expanded(),
|
||||
StateFilter.freeze({"some.other.state.type": {""}}, include_others=False),
|
||||
)
|
||||
|
||||
# Concrete-only state filters stay the same
|
||||
# (Case: member-only filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# Wildcard member-only state filters stay the same
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: None},
|
||||
include_others=False,
|
||||
),
|
||||
)
|
||||
|
||||
# If there is a wildcard in the non-member portion of the filter,
|
||||
# it's expanded to include ALL non-member events.
|
||||
# (Case: mixed filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
EventTypes.Member: {"@wombat:test", "@alicia:test"},
|
||||
"some.other.state.type": None,
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze(
|
||||
{EventTypes.Member: {"@wombat:test", "@alicia:test"}},
|
||||
include_others=True,
|
||||
),
|
||||
)
|
||||
|
||||
# If there is a wildcard in the non-member portion of the filter,
|
||||
# it's expanded to include ALL non-member events.
|
||||
# (Case: non-member-only filter)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
"some.other.state.type": None,
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
|
||||
)
|
||||
self.assertEqual(
|
||||
StateFilter.freeze(
|
||||
{
|
||||
"some.other.state.type": None,
|
||||
"yet.another.state.type": {"wombat"},
|
||||
},
|
||||
include_others=False,
|
||||
).return_expanded(),
|
||||
StateFilter.freeze({EventTypes.Member: set()}, include_others=True),
|
||||
)
|
||||
Reference in New Issue
Block a user