Compare commits
5 Commits
release-v1
...
hughns/msc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a6caec1adc | ||
|
|
9979d126d8 | ||
|
|
92cda2eb49 | ||
|
|
2aa0b5287a | ||
|
|
273e3b60ce |
@@ -137,7 +137,7 @@ fn get_runtime<'a>(reactor: &Bound<'a, PyAny>) -> PyResult<PyRef<'a, PyTokioRunt
|
||||
static DEFER: OnceCell<PyObject> = OnceCell::new();
|
||||
|
||||
/// Access to the `twisted.internet.defer` module.
|
||||
fn defer(py: Python<'_>) -> PyResult<&Bound<PyAny>> {
|
||||
fn defer(py: Python<'_>) -> PyResult<&Bound<'_, PyAny>> {
|
||||
Ok(DEFER
|
||||
.get_or_try_init(|| py.import("twisted.internet.defer").map(Into::into))?
|
||||
.bind(py))
|
||||
|
||||
@@ -11,6 +11,7 @@ pub mod http;
|
||||
pub mod http_client;
|
||||
pub mod identifier;
|
||||
pub mod matrix_const;
|
||||
pub mod msc4388_rendezvous;
|
||||
pub mod push;
|
||||
pub mod rendezvous;
|
||||
pub mod segmenter;
|
||||
@@ -54,6 +55,7 @@ fn synapse_rust(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
events::register_module(py, m)?;
|
||||
http_client::register_module(py, m)?;
|
||||
rendezvous::register_module(py, m)?;
|
||||
msc4388_rendezvous::register_module(py, m)?;
|
||||
segmenter::register_module(py, m)?;
|
||||
|
||||
Ok(())
|
||||
|
||||
381
rust/src/msc4388_rendezvous/mod.rs
Normal file
381
rust/src/msc4388_rendezvous/mod.rs
Normal file
@@ -0,0 +1,381 @@
|
||||
/*
|
||||
* This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
*
|
||||
* Copyright (C) 2025 Element Creations, Ltd
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as
|
||||
* published by the Free Software Foundation, either version 3 of the
|
||||
* License, or (at your option) any later version.
|
||||
*
|
||||
* See the GNU Affero General Public License for more details:
|
||||
* <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
*/
|
||||
|
||||
use std::{
|
||||
collections::BTreeMap,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use bytes::Bytes;
|
||||
use headers::{
|
||||
AccessControlAllowHeaders, AccessControlAllowMethods, AccessControlAllowOrigin, HeaderMapExt,
|
||||
};
|
||||
use http::header::HeaderName;
|
||||
use http::{header, HeaderMap, Method, Response, StatusCode};
|
||||
use pyo3::{
|
||||
pyclass, pymethods,
|
||||
types::{PyAnyMethods, PyModule, PyModuleMethods},
|
||||
Bound, IntoPyObject, Py, PyAny, PyObject, PyResult, Python,
|
||||
};
|
||||
use ulid::Ulid;
|
||||
|
||||
use self::session::Session;
|
||||
use crate::{
|
||||
errors::{NotFoundError, SynapseError},
|
||||
http::{http_request_from_twisted, http_response_to_twisted},
|
||||
UnwrapInfallible,
|
||||
};
|
||||
|
||||
mod session;
|
||||
|
||||
// Annoyingly we need to set the normal CORS headers on every response as the Python layer doesn't do it for us.
|
||||
// List is taken from https://spec.matrix.org/v1.16/client-server-api/#web-browser-clients
|
||||
fn prepare_headers(headers: &mut HeaderMap) {
|
||||
headers.typed_insert(AccessControlAllowOrigin::ANY);
|
||||
headers.typed_insert(AccessControlAllowMethods::from_iter([
|
||||
Method::POST,
|
||||
Method::GET,
|
||||
Method::PUT,
|
||||
Method::DELETE,
|
||||
Method::OPTIONS,
|
||||
]));
|
||||
headers.typed_insert(AccessControlAllowHeaders::from_iter([
|
||||
HeaderName::from_static("x-requested-with"),
|
||||
header::CONTENT_TYPE,
|
||||
header::AUTHORIZATION,
|
||||
]));
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
struct MSC4388RendezvousHandler {
|
||||
clock: PyObject,
|
||||
sessions: BTreeMap<Ulid, Session>,
|
||||
capacity: usize,
|
||||
max_content_length: u64,
|
||||
ttl: Duration,
|
||||
}
|
||||
|
||||
impl MSC4388RendezvousHandler {
|
||||
/// Check the length of the data parameter and throw error if invalid.
|
||||
fn check_data_length(&self, data: &str) -> PyResult<()> {
|
||||
let data_length = data.len() as u64;
|
||||
if data_length > self.max_content_length {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers);
|
||||
|
||||
return Err(SynapseError::new(
|
||||
StatusCode::PAYLOAD_TOO_LARGE,
|
||||
"Payload too large".to_owned(),
|
||||
"M_TOO_LARGE",
|
||||
None,
|
||||
Some(headers),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Evict expired sessions and remove the oldest sessions until we're under the capacity.
|
||||
fn evict(&mut self, now: SystemTime) {
|
||||
// First remove all the entries which expired
|
||||
self.sessions.retain(|_, session| !session.expired(now));
|
||||
|
||||
// Then we remove the oldest entries until we're under the limit
|
||||
while self.sessions.len() > self.capacity {
|
||||
self.sessions.pop_first();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pymethods]
|
||||
impl MSC4388RendezvousHandler {
|
||||
#[new]
|
||||
#[pyo3(signature = (homeserver, /, capacity=100, max_content_length=4*1024, eviction_interval=60*1000, ttl=2*60*1000))]
|
||||
fn new(
|
||||
py: Python<'_>,
|
||||
homeserver: &Bound<'_, PyAny>,
|
||||
capacity: usize,
|
||||
max_content_length: u64,
|
||||
eviction_interval: u64,
|
||||
ttl: u64,
|
||||
) -> PyResult<Py<Self>> {
|
||||
let clock = homeserver
|
||||
.call_method0("get_clock")?
|
||||
.into_pyobject(py)
|
||||
.unwrap_infallible()
|
||||
.unbind();
|
||||
|
||||
// Construct a Python object so that we can get a reference to the
|
||||
// evict method and schedule it to run.
|
||||
let self_ = Py::new(
|
||||
py,
|
||||
Self {
|
||||
clock,
|
||||
sessions: BTreeMap::new(),
|
||||
capacity,
|
||||
max_content_length,
|
||||
ttl: Duration::from_millis(ttl),
|
||||
},
|
||||
)?;
|
||||
|
||||
let evict = self_.getattr(py, "_evict")?;
|
||||
homeserver.call_method0("get_clock")?.call_method(
|
||||
"looping_call",
|
||||
(evict, eviction_interval),
|
||||
None,
|
||||
)?;
|
||||
|
||||
Ok(self_)
|
||||
}
|
||||
|
||||
fn _evict(&mut self, py: Python<'_>) -> PyResult<()> {
|
||||
let clock = self.clock.bind(py);
|
||||
let now: u64 = clock.call_method0("time_msec")?.extract()?;
|
||||
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
|
||||
self.evict(now);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_post(&mut self, py: Python<'_>, twisted_request: &Bound<'_, PyAny>) -> PyResult<()> {
|
||||
let request = http_request_from_twisted(twisted_request)?;
|
||||
|
||||
let clock = self.clock.bind(py);
|
||||
let now: u64 = clock.call_method0("time_msec")?.extract()?;
|
||||
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
|
||||
|
||||
// We trigger an immediate eviction if we're at 2x the capacity
|
||||
if self.sessions.len() >= self.capacity * 2 {
|
||||
self.evict(now);
|
||||
}
|
||||
|
||||
// Generate a new ULID for the session from the current time.
|
||||
let id = Ulid::from_datetime(now);
|
||||
|
||||
// parse JSON body out of request
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_slice(&request.into_body()).map_err(|_| {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers);
|
||||
|
||||
SynapseError::new(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Invalid JSON in request body".to_owned(),
|
||||
"M_INVALID_PARAM",
|
||||
None,
|
||||
Some(headers),
|
||||
)
|
||||
})?;
|
||||
|
||||
let data: String = json["data"].as_str().map(|s| s.to_owned()).ok_or_else(|| {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers);
|
||||
|
||||
SynapseError::new(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Missing 'data' field in JSON body".to_owned(),
|
||||
"M_INVALID_PARAM",
|
||||
None,
|
||||
Some(headers),
|
||||
)
|
||||
})?;
|
||||
|
||||
self.check_data_length(&data)?;
|
||||
|
||||
let session = Session::new(id, data, now, self.ttl);
|
||||
|
||||
let response_body = serde_json::to_string(&session.post_response()).map_err(|_| {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers);
|
||||
|
||||
SynapseError::new(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to serialize response".to_owned(),
|
||||
"M_UNKNOWN",
|
||||
None,
|
||||
Some(headers),
|
||||
)
|
||||
})?;
|
||||
let mut response = Response::new(response_body.as_bytes());
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
let headers = response.headers_mut();
|
||||
prepare_headers(headers);
|
||||
http_response_to_twisted(twisted_request, response)?;
|
||||
|
||||
self.sessions.insert(id, session);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_get(
|
||||
&mut self,
|
||||
py: Python<'_>,
|
||||
twisted_request: &Bound<'_, PyAny>,
|
||||
id: &str,
|
||||
) -> PyResult<()> {
|
||||
let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
|
||||
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
|
||||
|
||||
let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
|
||||
let session = self
|
||||
.sessions
|
||||
.get(&id)
|
||||
.filter(|s| !s.expired(now))
|
||||
.ok_or_else(NotFoundError::new)?;
|
||||
|
||||
let response_body = serde_json::to_string(&session.get_response()).map_err(|_| {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers);
|
||||
|
||||
SynapseError::new(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to serialize response".to_owned(),
|
||||
"M_UNKNOWN",
|
||||
None,
|
||||
Some(headers),
|
||||
)
|
||||
})?;
|
||||
let mut response = Response::new(response_body.as_bytes());
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
prepare_headers(response.headers_mut());
|
||||
http_response_to_twisted(twisted_request, response)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_put(
|
||||
&mut self,
|
||||
py: Python<'_>,
|
||||
twisted_request: &Bound<'_, PyAny>,
|
||||
id: &str,
|
||||
) -> PyResult<()> {
|
||||
let request = http_request_from_twisted(twisted_request)?;
|
||||
|
||||
// parse JSON body out of request
|
||||
let json: serde_json::Value =
|
||||
serde_json::from_slice(&request.into_body()).map_err(|_| {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers);
|
||||
|
||||
SynapseError::new(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Invalid JSON in request body".to_owned(),
|
||||
"M_INVALID_PARAM",
|
||||
None,
|
||||
Some(headers),
|
||||
)
|
||||
})?;
|
||||
|
||||
let sequence_token: String = json["sequence_token"]
|
||||
.as_str()
|
||||
.map(|s| s.to_owned())
|
||||
.ok_or_else(|| {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers);
|
||||
|
||||
SynapseError::new(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Missing 'sequence_token' field in JSON body".to_owned(),
|
||||
"M_INVALID_PARAM",
|
||||
None,
|
||||
Some(headers),
|
||||
)
|
||||
})?;
|
||||
|
||||
let data: String = json["data"].as_str().map(|s| s.to_owned()).ok_or_else(|| {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers);
|
||||
|
||||
SynapseError::new(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Missing 'data' field in JSON body".to_owned(),
|
||||
"M_INVALID_PARAM",
|
||||
None,
|
||||
Some(headers),
|
||||
)
|
||||
})?;
|
||||
|
||||
self.check_data_length(&data)?;
|
||||
|
||||
let now: u64 = self.clock.call_method0(py, "time_msec")?.extract(py)?;
|
||||
let now = SystemTime::UNIX_EPOCH + Duration::from_millis(now);
|
||||
|
||||
let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
|
||||
let session = self
|
||||
.sessions
|
||||
.get_mut(&id)
|
||||
.filter(|s| !s.expired(now))
|
||||
.ok_or_else(NotFoundError::new)?;
|
||||
|
||||
if !session.sequence_token().eq(&sequence_token) {
|
||||
let mut headers = HeaderMap::new();
|
||||
prepare_headers(&mut headers);
|
||||
|
||||
return Err(SynapseError::new(
|
||||
StatusCode::CONFLICT,
|
||||
"sequence_token does not match".to_owned(),
|
||||
"IO_ELEMENT_MSC4388_CONCURRENT_WRITE",
|
||||
None,
|
||||
Some(headers),
|
||||
));
|
||||
}
|
||||
|
||||
session.update(data, now);
|
||||
|
||||
let response_body = serde_json::to_string(&session.put_response()).map_err(|_| {
|
||||
SynapseError::new(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
"Failed to serialize response".to_owned(),
|
||||
"M_UNKNOWN",
|
||||
None,
|
||||
None,
|
||||
)
|
||||
})?;
|
||||
let mut response = Response::new(response_body.as_bytes());
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
prepare_headers(response.headers_mut());
|
||||
http_response_to_twisted(twisted_request, response)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_delete(&mut self, twisted_request: &Bound<'_, PyAny>, id: &str) -> PyResult<()> {
|
||||
let _request = http_request_from_twisted(twisted_request)?;
|
||||
|
||||
let id: Ulid = id.parse().map_err(|_| NotFoundError::new())?;
|
||||
let _session = self.sessions.remove(&id).ok_or_else(NotFoundError::new)?;
|
||||
|
||||
let mut response = Response::new(Bytes::new());
|
||||
*response.status_mut() = StatusCode::OK;
|
||||
prepare_headers(response.headers_mut());
|
||||
http_response_to_twisted(twisted_request, response)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
||||
let child_module = PyModule::new(py, "msc4388_rendezvous")?;
|
||||
|
||||
child_module.add_class::<MSC4388RendezvousHandler>()?;
|
||||
|
||||
m.add_submodule(&child_module)?;
|
||||
|
||||
// We need to manually add the module to sys.modules to make `from
|
||||
// synapse.synapse_rust import rendezvous` work.
|
||||
py.import("sys")?
|
||||
.getattr("modules")?
|
||||
.set_item("synapse.synapse_rust.msc4388_rendezvous", child_module)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
109
rust/src/msc4388_rendezvous/session.rs
Normal file
109
rust/src/msc4388_rendezvous/session.rs
Normal file
@@ -0,0 +1,109 @@
|
||||
/*
|
||||
* This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
*
|
||||
* Copyright (C) 2025 Element Creations, Ltd
|
||||
*
|
||||
* This program is free software: you can redistribute it and/or modify
|
||||
* it under the terms of the GNU Affero General Public License as
|
||||
* published by the Free Software Foundation, either version 3 of the
|
||||
* License, or (at your option) any later version.
|
||||
*
|
||||
* See the GNU Affero General Public License for more details:
|
||||
* <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
*/
|
||||
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
|
||||
use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine as _};
|
||||
use serde::Serialize;
|
||||
use sha2::{Digest, Sha256};
|
||||
use ulid::Ulid;
|
||||
|
||||
/// A single session, containing data, metadata, and expiry information.
|
||||
pub struct Session {
|
||||
id: Ulid,
|
||||
hash: [u8; 32],
|
||||
data: String,
|
||||
last_modified: SystemTime,
|
||||
expires: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct PostResponse {
|
||||
id: String,
|
||||
sequence_token: String,
|
||||
expires_ts: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct GetResponse {
|
||||
data: String,
|
||||
sequence_token: String,
|
||||
expires_ts: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct PutResponse {
|
||||
sequence_token: String,
|
||||
}
|
||||
|
||||
impl Session {
|
||||
/// Create a new session with the given data and time-to-live.
|
||||
pub fn new(id: Ulid, data: String, now: SystemTime, ttl: Duration) -> Self {
|
||||
let hash = Sha256::digest(&data).into();
|
||||
Self {
|
||||
id,
|
||||
hash,
|
||||
data,
|
||||
expires: now + ttl,
|
||||
last_modified: now,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the session has expired at the given time.
|
||||
pub fn expired(&self, now: SystemTime) -> bool {
|
||||
self.expires <= now
|
||||
}
|
||||
|
||||
/// Update the session with new data and last modified time.
|
||||
pub fn update(&mut self, data: String, now: SystemTime) {
|
||||
self.hash = Sha256::digest(&data).into();
|
||||
self.data = data;
|
||||
self.last_modified = now;
|
||||
}
|
||||
|
||||
/// The sequence token for the session.
|
||||
pub fn sequence_token(&self) -> String {
|
||||
URL_SAFE_NO_PAD.encode(self.hash)
|
||||
}
|
||||
|
||||
pub fn get_response(&self) -> GetResponse {
|
||||
GetResponse {
|
||||
data: self.data.clone(),
|
||||
sequence_token: self.sequence_token(),
|
||||
expires_ts: self
|
||||
.expires
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn post_response(&self) -> PostResponse {
|
||||
PostResponse {
|
||||
id: self.id.to_string(),
|
||||
sequence_token: self.sequence_token(),
|
||||
expires_ts: self
|
||||
.expires
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn put_response(&self) -> PutResponse {
|
||||
PutResponse {
|
||||
sequence_token: self.sequence_token(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -523,7 +523,7 @@ class ExperimentalConfig(Config):
|
||||
"msc4069_profile_inhibit_propagation", False
|
||||
)
|
||||
|
||||
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
|
||||
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code - 2024 version:
|
||||
self.msc4108_enabled = experimental.get("msc4108_enabled", False)
|
||||
|
||||
self.msc4108_delegation_endpoint: Optional[str] = experimental.get(
|
||||
@@ -548,6 +548,25 @@ class ExperimentalConfig(Config):
|
||||
("experimental", "msc4108_delegation_endpoint"),
|
||||
)
|
||||
|
||||
# MSC4388: Secure out-of-band channel for sign in with QR:
|
||||
msc4388_mode = experimental.get("msc4388_mode", "off")
|
||||
|
||||
if ["off", "public", "authenticated"].count(msc4388_mode) != 1:
|
||||
raise ConfigError(
|
||||
"msc4388_mode must be one of 'off', 'public' or 'authenticated'",
|
||||
("experimental", "msc4388_mode"),
|
||||
)
|
||||
self.msc4388_enabled: bool = msc4388_mode != "off"
|
||||
self.msc4388_requires_authentication: bool = msc4388_mode == "authenticated"
|
||||
|
||||
if self.msc4388_enabled and not (
|
||||
config.get("matrix_authentication_service") or {}
|
||||
).get("enabled", False):
|
||||
raise ConfigError(
|
||||
"MSC4388 requires matrix_authentication_service to be enabled",
|
||||
("experimental", "msc4388_enabled"),
|
||||
)
|
||||
|
||||
# MSC4133: Custom profile fields
|
||||
self.msc4133_enabled: bool = experimental.get("msc4133_enabled", False)
|
||||
|
||||
|
||||
@@ -68,9 +68,55 @@ class MSC4108RendezvousServlet(RestServlet):
|
||||
self._handler.handle_post(request)
|
||||
|
||||
|
||||
class MSC4388CreateRendezvousServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/io.element.msc4388/rendezvous$", releases=[], v1=False, unstable=True
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer") -> None:
|
||||
super().__init__()
|
||||
self._handler = hs.get_msc4388_rendezvous_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.require_authentication = (
|
||||
hs.config.experimental.msc4388_requires_authentication
|
||||
)
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> None:
|
||||
if self.require_authentication:
|
||||
# This will raise if the user is not authenticated
|
||||
await self.auth.get_user_by_req(request)
|
||||
self._handler.handle_post(request)
|
||||
|
||||
|
||||
class MSC4388UpdateRendezvousServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/io.element.msc4388/rendezvous/(?P<rendezvous_id>[^/]+)$",
|
||||
releases=[],
|
||||
v1=False,
|
||||
unstable=True,
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer") -> None:
|
||||
super().__init__()
|
||||
self._handler = hs.get_msc4388_rendezvous_handler()
|
||||
|
||||
def on_GET(self, request: SynapseRequest, rendezvous_id: str) -> None:
|
||||
self._handler.handle_get(request, rendezvous_id)
|
||||
|
||||
def on_PUT(self, request: SynapseRequest, rendezvous_id: str) -> None:
|
||||
self._handler.handle_put(request, rendezvous_id)
|
||||
|
||||
def on_DELETE(self, request: SynapseRequest, rendezvous_id: str) -> None:
|
||||
self._handler.handle_delete(request, rendezvous_id)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc4108_enabled:
|
||||
MSC4108RendezvousServlet(hs).register(http_server)
|
||||
|
||||
if hs.config.experimental.msc4108_delegation_endpoint is not None:
|
||||
MSC4108DelegationRendezvousServlet(hs).register(http_server)
|
||||
|
||||
if hs.config.experimental.msc4388_enabled:
|
||||
MSC4388CreateRendezvousServlet(hs).register(http_server)
|
||||
MSC4388UpdateRendezvousServlet(hs).register(http_server)
|
||||
|
||||
@@ -161,7 +161,7 @@ class VersionsRestServlet(RestServlet):
|
||||
"org.matrix.msc4069": self.config.experimental.msc4069_profile_inhibit_propagation,
|
||||
# Allows clients to handle push for encrypted events.
|
||||
"org.matrix.msc4028": self.config.experimental.msc4028_push_encrypted_events,
|
||||
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code
|
||||
# MSC4108: Mechanism to allow OIDC sign in and E2EE set up via QR code - 2024 version
|
||||
"org.matrix.msc4108": (
|
||||
self.config.experimental.msc4108_enabled
|
||||
or (
|
||||
@@ -169,6 +169,8 @@ class VersionsRestServlet(RestServlet):
|
||||
is not None
|
||||
)
|
||||
),
|
||||
# MSC4388: Secure out-of-band channel for sign in with QR
|
||||
"io.element.msc4388": (self.config.experimental.msc4388_enabled),
|
||||
# MSC4140: Delayed events
|
||||
"org.matrix.msc4140": bool(self.config.server.max_event_delay_ms),
|
||||
# Simplified sliding sync
|
||||
|
||||
@@ -170,6 +170,7 @@ from synapse.state import StateHandler, StateResolutionHandler
|
||||
from synapse.storage import Databases
|
||||
from synapse.storage.controllers import StorageControllers
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.synapse_rust.msc4388_rendezvous import MSC4388RendezvousHandler
|
||||
from synapse.synapse_rust.rendezvous import RendezvousHandler
|
||||
from synapse.types import DomainSpecificString, ISynapseReactor
|
||||
from synapse.util import SYNAPSE_VERSION
|
||||
@@ -1156,6 +1157,10 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
def get_rendezvous_handler(self) -> RendezvousHandler:
|
||||
return RendezvousHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_msc4388_rendezvous_handler(self) -> MSC4388RendezvousHandler:
|
||||
return MSC4388RendezvousHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_outbound_redis_connection(self) -> "ConnectionHandler":
|
||||
"""
|
||||
|
||||
30
synapse/synapse_rust/msc4388_rendezvous.pyi
Normal file
30
synapse/synapse_rust/msc4388_rendezvous.pyi
Normal file
@@ -0,0 +1,30 @@
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 Element Creations, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
from twisted.web.iweb import IRequest
|
||||
|
||||
from synapse.server import HomeServer
|
||||
|
||||
class MSC4388RendezvousHandler:
|
||||
def __init__(
|
||||
self,
|
||||
homeserver: HomeServer,
|
||||
/,
|
||||
capacity: int = 100, # This should be configurable
|
||||
max_content_length: int = 4 * 1024, # MSC4388 specifies maximum of 4KB
|
||||
eviction_interval: int = 60 * 1000,
|
||||
ttl: int = 2 * 60 * 1000, # MSC4388 specifies minimum of 120 seconds
|
||||
) -> None: ...
|
||||
def handle_post(self, request: IRequest) -> None: ...
|
||||
def handle_get(self, request: IRequest, session_id: str) -> None: ...
|
||||
def handle_put(self, request: IRequest, session_id: str) -> None: ...
|
||||
def handle_delete(self, request: IRequest, session_id: str) -> None: ...
|
||||
632
tests/rest/client/test_msc4388_rendezvous.py
Normal file
632
tests/rest/client/test_msc4388_rendezvous.py
Normal file
@@ -0,0 +1,632 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 Element Creations, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
|
||||
import json
|
||||
import urllib.parse
|
||||
from typing import Any, Mapping
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.internet.testing import MemoryReactor
|
||||
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, rendezvous
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import UserID
|
||||
from synapse.util.clock import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import checked_cast, override_config
|
||||
from tests.utils import HAS_AUTHLIB
|
||||
|
||||
rz_endpoint = "/_matrix/client/unstable/io.element.msc4388/rendezvous"
|
||||
|
||||
|
||||
class RendezvousServletTestCase(unittest.HomeserverTestCase):
|
||||
"""
|
||||
Test the experimental MSC4388 rendezvous endpoint.
|
||||
"""
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
login.register_servlets,
|
||||
rendezvous.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.hs = self.setup_test_homeserver()
|
||||
return self.hs
|
||||
|
||||
def setup_mock_oauth(self) -> None:
|
||||
"""
|
||||
This isn't a very elegant way to mock the OAuth API, but it works for our purposes.
|
||||
"""
|
||||
|
||||
# Import this here so that we've checked that authlib is available.
|
||||
from synapse.api.auth.mas import MasDelegatedAuth
|
||||
|
||||
self.auth = checked_cast(MasDelegatedAuth, self.hs.get_auth())
|
||||
|
||||
self._rust_client = Mock(spec=["post"])
|
||||
self._rust_client.post = self._mock_oauth_response
|
||||
self.auth._rust_http_client = self._rust_client
|
||||
|
||||
async def _mock_oauth_response(
|
||||
self,
|
||||
url: str,
|
||||
response_limit: int,
|
||||
headers: Mapping[str, str],
|
||||
request_body: str,
|
||||
) -> bytes:
|
||||
# get the token from the request body which is form encoded
|
||||
parsed_body = urllib.parse.parse_qs(request_body)
|
||||
token = parsed_body.get("token", [""])[0]
|
||||
|
||||
if not token.startswith("mock_token_"):
|
||||
return bytes(json.dumps({"active": False}).encode("utf-8"))
|
||||
token = token.replace("mock_token_", "")
|
||||
|
||||
username, device_id = token.split("_", 1)
|
||||
user_id = UserID(username, self.hs.hostname)
|
||||
store = self.hs.get_datastores().main
|
||||
|
||||
# Check th user exists in the store
|
||||
user_info = await store.get_user_by_id(user_id=user_id.to_string())
|
||||
if user_info is None:
|
||||
return bytes(json.dumps({"active": False}).encode("utf-8"))
|
||||
|
||||
# Check the device exists in the store
|
||||
device = await store.get_device(
|
||||
user_id=user_id.to_string(), device_id=device_id
|
||||
)
|
||||
if device is None:
|
||||
return bytes(json.dumps({"active": False}).encode("utf-8"))
|
||||
|
||||
return bytes(
|
||||
json.dumps(
|
||||
{
|
||||
"active": True,
|
||||
"scope": "urn:matrix:client:device:"
|
||||
+ device_id
|
||||
+ " urn:matrix:client:api:*",
|
||||
"username": username,
|
||||
}
|
||||
).encode("utf-8")
|
||||
)
|
||||
|
||||
def register_oauth_user(self, username: str, device_id: str) -> str:
|
||||
# Provision the user and the device
|
||||
store = self.hs.get_datastores().main
|
||||
user_id = UserID(username, self.hs.hostname)
|
||||
|
||||
self.get_success(store.register_user(user_id=user_id.to_string()))
|
||||
self.get_success(
|
||||
store.store_device(
|
||||
user_id=user_id.to_string(),
|
||||
device_id=device_id,
|
||||
initial_device_display_name=None,
|
||||
)
|
||||
)
|
||||
# Generate an access token for the device
|
||||
return "mock_token_" + username + "_" + device_id
|
||||
|
||||
def test_disabled(self) -> None:
|
||||
channel = self.make_request("POST", rz_endpoint, {}, access_token=None)
|
||||
self.assertEqual(channel.code, 404)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"experimental_features": {
|
||||
"msc4388_mode": "off",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_off(self) -> None:
|
||||
channel = self.make_request("POST", rz_endpoint, {}, access_token=None)
|
||||
self.assertEqual(channel.code, 404)
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"matrix_authentication_service": {
|
||||
"enabled": True,
|
||||
"secret": "secret_value",
|
||||
"endpoint": "https://issuer",
|
||||
},
|
||||
"experimental_features": {
|
||||
"msc4388_mode": "public",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_rendezvous_public(self) -> None:
|
||||
"""
|
||||
Test the MSC4108 rendezvous endpoint, including:
|
||||
- Creating a session
|
||||
- Getting the data back
|
||||
- Updating the data
|
||||
- Deleting the data
|
||||
- Sequence token handling
|
||||
"""
|
||||
# We can post arbitrary data to the endpoint
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "foo=bar"},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
rendezvous_id = channel.json_body["id"]
|
||||
sequence_token = channel.json_body["sequence_token"]
|
||||
expires_ts = channel.json_body["expires_ts"]
|
||||
self.assertGreater(expires_ts, self.hs.get_clock().time_msec())
|
||||
|
||||
session_endpoint = rz_endpoint + f"/{rendezvous_id}"
|
||||
|
||||
# We can get the data back
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["data"], "foo=bar")
|
||||
self.assertEqual(channel.json_body["sequence_token"], sequence_token)
|
||||
self.assertEqual(channel.json_body["expires_ts"], expires_ts)
|
||||
|
||||
# We can update the data
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
session_endpoint,
|
||||
{"sequence_token": sequence_token, "data": "foo=baz"},
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
old_sequence_token = sequence_token
|
||||
new_sequence_token = channel.json_body["sequence_token"]
|
||||
|
||||
# If we try to update it again with the old etag, it should fail
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
session_endpoint,
|
||||
{"sequence_token": old_sequence_token, "data": "bar=baz"},
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 409)
|
||||
self.assertEqual(
|
||||
channel.json_body["errcode"], "IO_ELEMENT_MSC4388_CONCURRENT_WRITE"
|
||||
)
|
||||
|
||||
# We should get the updated data
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["data"], "foo=baz")
|
||||
self.assertEqual(channel.json_body["sequence_token"], new_sequence_token)
|
||||
self.assertEqual(channel.json_body["expires_ts"], expires_ts)
|
||||
|
||||
# We can delete the data
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# If we try to get the data again, it should fail
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"matrix_authentication_service": {
|
||||
"enabled": True,
|
||||
"secret": "secret_value",
|
||||
"endpoint": "https://issuer",
|
||||
},
|
||||
"experimental_features": {
|
||||
"msc4388_mode": "authenticated",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_rendezvous_requires_authentication(self) -> None:
|
||||
"""
|
||||
Test the MSC4108 rendezvous endpoint when configured with the mode authenticated, including:
|
||||
- Creating a session
|
||||
- Getting the data back
|
||||
- Updating the data
|
||||
- Deleting the data
|
||||
- Sequence token handling
|
||||
"""
|
||||
self.setup_mock_oauth()
|
||||
alice_token = self.register_oauth_user("alice", "device1")
|
||||
|
||||
# This should fail without authentication:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "foo=bar"},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 401)
|
||||
|
||||
# This should work as we are now authenticated
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "foo=bar"},
|
||||
access_token=alice_token,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
rendezvous_id = channel.json_body["id"]
|
||||
sequence_token = channel.json_body["sequence_token"]
|
||||
expires_ts = channel.json_body["expires_ts"]
|
||||
self.assertGreater(expires_ts, self.hs.get_clock().time_msec())
|
||||
|
||||
session_endpoint = rz_endpoint + f"/{rendezvous_id}"
|
||||
|
||||
# We can get the data back without authentication
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["data"], "foo=bar")
|
||||
self.assertEqual(channel.json_body["sequence_token"], sequence_token)
|
||||
self.assertEqual(channel.json_body["expires_ts"], expires_ts)
|
||||
|
||||
# We can update the data without authentication
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
session_endpoint,
|
||||
{"sequence_token": sequence_token, "data": "foo=baz"},
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
new_sequence_token = channel.json_body["sequence_token"]
|
||||
|
||||
# We should get the updated data without authentication
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["data"], "foo=baz")
|
||||
self.assertEqual(channel.json_body["sequence_token"], new_sequence_token)
|
||||
self.assertEqual(channel.json_body["expires_ts"], expires_ts)
|
||||
|
||||
# We can delete the data without authentication
|
||||
channel = self.make_request(
|
||||
"DELETE",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# If we try to get the data again, it should fail
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"matrix_authentication_service": {
|
||||
"enabled": True,
|
||||
"secret": "secret_value",
|
||||
"endpoint": "https://issuer",
|
||||
},
|
||||
"experimental_features": {
|
||||
"msc4388_mode": "public",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_expiration(self) -> None:
|
||||
"""
|
||||
Test that entries are evicted after a TTL.
|
||||
"""
|
||||
# Start a new session
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "foo=bar"},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
session_endpoint = rz_endpoint + "/" + channel.json_body["id"]
|
||||
|
||||
# Sanity check that we can get the data back
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["data"], "foo=bar")
|
||||
|
||||
# Advance the clock, TTL of entries is 2 minutes
|
||||
self.reactor.advance(120)
|
||||
|
||||
# Get the data back, it should be gone
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 404)
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"matrix_authentication_service": {
|
||||
"enabled": True,
|
||||
"secret": "secret_value",
|
||||
"endpoint": "https://issuer",
|
||||
},
|
||||
"experimental_features": {
|
||||
"msc4388_mode": "public",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_capacity(self) -> None:
|
||||
"""
|
||||
Test that a capacity limit is enforced on the rendezvous sessions, as old
|
||||
entries are evicted at an interval when the limit is reached.
|
||||
"""
|
||||
# Start a new session
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "foo=bar"},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
session_endpoint = rz_endpoint + "/" + channel.json_body["id"]
|
||||
|
||||
# Sanity check that we can get the data back
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["data"], "foo=bar")
|
||||
|
||||
# We advance the clock to make sure that this entry is the "lowest" in the session list
|
||||
self.reactor.advance(1)
|
||||
|
||||
# Start a lot of new sessions
|
||||
for _ in range(100):
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "foo=bar"},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# Get the data back, it should still be there, as the eviction hasn't run yet
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# Advance the clock, as it will trigger the eviction
|
||||
self.reactor.advance(59)
|
||||
|
||||
# Get the data back, it should be gone
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404)
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"matrix_authentication_service": {
|
||||
"enabled": True,
|
||||
"secret": "secret_value",
|
||||
"endpoint": "https://issuer",
|
||||
},
|
||||
"experimental_features": {
|
||||
"msc4388_mode": "public",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_hard_capacity(self) -> None:
|
||||
"""
|
||||
Test that a hard capacity limit is enforced on the rendezvous sessions, as old
|
||||
entries are evicted immediately when the limit is reached.
|
||||
"""
|
||||
# Start a new session
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "foo=bar"},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
session_endpoint = rz_endpoint + "/" + channel.json_body["id"]
|
||||
# We advance the clock to make sure that this entry is the "lowest" in the session list
|
||||
self.reactor.advance(1)
|
||||
|
||||
# Sanity check that we can get the data back
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["data"], "foo=bar")
|
||||
|
||||
# Start a lot of new sessions
|
||||
for _ in range(200):
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "foo=bar"},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
|
||||
# Get the data back, it should already be gone as we hit the hard limit
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
session_endpoint,
|
||||
access_token=None,
|
||||
)
|
||||
|
||||
self.assertEqual(channel.code, 404)
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"matrix_authentication_service": {
|
||||
"enabled": True,
|
||||
"secret": "secret_value",
|
||||
"endpoint": "https://issuer",
|
||||
},
|
||||
"experimental_features": {
|
||||
"msc4388_mode": "public",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_data_type(self) -> None:
|
||||
"""
|
||||
Test that the data field is restricted to string.
|
||||
"""
|
||||
invalid_datas: list[Any] = [123214, ["asd"], {"asd": "asdsad"}, None]
|
||||
|
||||
# We cannot post invalid non-string data field values to the endpoint
|
||||
for invalid_data in invalid_datas:
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": invalid_data},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 400)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
|
||||
|
||||
# Make a valid request
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "test"},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
rendezvous_id = channel.json_body["id"]
|
||||
sequence_token = channel.json_body["sequence_token"]
|
||||
|
||||
session_endpoint = rz_endpoint + f"/{rendezvous_id}"
|
||||
|
||||
# We can't update the data with invalid data
|
||||
for invalid_data in invalid_datas:
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
session_endpoint,
|
||||
{"sequence_token": sequence_token, "data": invalid_data},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 400)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
|
||||
|
||||
@unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
|
||||
@override_config(
|
||||
{
|
||||
"disable_registration": True,
|
||||
"matrix_authentication_service": {
|
||||
"enabled": True,
|
||||
"secret": "secret_value",
|
||||
"endpoint": "https://issuer",
|
||||
},
|
||||
"experimental_features": {
|
||||
"msc4388_mode": "public",
|
||||
},
|
||||
}
|
||||
)
|
||||
def test_max_length(self) -> None:
|
||||
"""
|
||||
Test that the data max length is restricted.
|
||||
"""
|
||||
too_long_data = "a" * 5000 # MSC4108 specifies 4KB max length
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": too_long_data},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 413)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_TOO_LARGE")
|
||||
|
||||
# Make a valid request
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
rz_endpoint,
|
||||
{"data": "test"},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 200)
|
||||
rendezvous_id = channel.json_body["id"]
|
||||
sequence_token = channel.json_body["sequence_token"]
|
||||
|
||||
session_endpoint = rz_endpoint + f"/{rendezvous_id}"
|
||||
|
||||
# We can't update the data with invalid data
|
||||
channel = self.make_request(
|
||||
"PUT",
|
||||
session_endpoint,
|
||||
{"sequence_token": sequence_token, "data": too_long_data},
|
||||
access_token=None,
|
||||
)
|
||||
self.assertEqual(channel.code, 413)
|
||||
self.assertEqual(channel.json_body["errcode"], "M_TOO_LARGE")
|
||||
Reference in New Issue
Block a user