|
|
|
|
@@ -12,58 +12,149 @@
|
|
|
|
|
* <https://www.gnu.org/licenses/agpl-3.0.html>.
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
use std::{collections::HashMap, future::Future, panic::AssertUnwindSafe, sync::LazyLock};
|
|
|
|
|
use std::{collections::HashMap, future::Future};
|
|
|
|
|
|
|
|
|
|
use anyhow::Context;
|
|
|
|
|
use futures::{FutureExt, TryStreamExt};
|
|
|
|
|
use pyo3::{exceptions::PyException, prelude::*, types::PyString};
|
|
|
|
|
use futures::TryStreamExt;
|
|
|
|
|
use once_cell::sync::OnceCell;
|
|
|
|
|
use pyo3::{create_exception, exceptions::PyException, prelude::*};
|
|
|
|
|
use reqwest::RequestBuilder;
|
|
|
|
|
use tokio::runtime::Runtime;
|
|
|
|
|
|
|
|
|
|
use crate::errors::HttpResponseException;
|
|
|
|
|
|
|
|
|
|
/// The tokio runtime that we're using to run async Rust libs.
|
|
|
|
|
static RUNTIME: LazyLock<Runtime> = LazyLock::new(|| {
|
|
|
|
|
tokio::runtime::Builder::new_multi_thread()
|
|
|
|
|
.worker_threads(4)
|
|
|
|
|
.enable_all()
|
|
|
|
|
.build()
|
|
|
|
|
.unwrap()
|
|
|
|
|
});
|
|
|
|
|
create_exception!(
|
|
|
|
|
synapse.synapse_rust.http_client,
|
|
|
|
|
RustPanicError,
|
|
|
|
|
PyException,
|
|
|
|
|
"A panic which happened in a Rust future"
|
|
|
|
|
);
|
|
|
|
|
|
|
|
|
|
/// A reference to the `Deferred` python class.
|
|
|
|
|
static DEFERRED_CLASS: LazyLock<PyObject> = LazyLock::new(|| {
|
|
|
|
|
Python::with_gil(|py| {
|
|
|
|
|
py.import("twisted.internet.defer")
|
|
|
|
|
.expect("module 'twisted.internet.defer' should be importable")
|
|
|
|
|
.getattr("Deferred")
|
|
|
|
|
.expect("module 'twisted.internet.defer' should have a 'Deferred' class")
|
|
|
|
|
.unbind()
|
|
|
|
|
})
|
|
|
|
|
});
|
|
|
|
|
impl RustPanicError {
|
|
|
|
|
fn from_panic(panic_err: &(dyn std::any::Any + Send + 'static)) -> PyErr {
|
|
|
|
|
// Apparently this is how you extract the panic message from a panic
|
|
|
|
|
let panic_message = if let Some(str_slice) = panic_err.downcast_ref::<&str>() {
|
|
|
|
|
str_slice
|
|
|
|
|
} else if let Some(string) = panic_err.downcast_ref::<String>() {
|
|
|
|
|
string
|
|
|
|
|
} else {
|
|
|
|
|
"unknown error"
|
|
|
|
|
};
|
|
|
|
|
Self::new_err(panic_message.to_owned())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// A reference to the twisted `reactor`.
|
|
|
|
|
static TWISTED_REACTOR: LazyLock<Py<PyModule>> = LazyLock::new(|| {
|
|
|
|
|
Python::with_gil(|py| {
|
|
|
|
|
py.import("twisted.internet.reactor")
|
|
|
|
|
.expect("module 'twisted.internet.reactor' should be importable")
|
|
|
|
|
.unbind()
|
|
|
|
|
})
|
|
|
|
|
});
|
|
|
|
|
/// This is the name of the attribute where we store the runtime on the reactor
|
|
|
|
|
static TOKIO_RUNTIME_ATTR: &str = "__synapse_rust_tokio_runtime";
|
|
|
|
|
|
|
|
|
|
/// A Python wrapper around a Tokio runtime.
|
|
|
|
|
///
|
|
|
|
|
/// This allows us to 'store' the runtime on the reactor instance, starting it
|
|
|
|
|
/// when the reactor starts, and stopping it when the reactor shuts down.
|
|
|
|
|
#[pyclass]
|
|
|
|
|
struct PyTokioRuntime {
|
|
|
|
|
runtime: Option<Runtime>,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[pymethods]
|
|
|
|
|
impl PyTokioRuntime {
|
|
|
|
|
fn start(&mut self) -> PyResult<()> {
|
|
|
|
|
// TODO: allow customization of the runtime like the number of threads
|
|
|
|
|
let runtime = tokio::runtime::Builder::new_multi_thread()
|
|
|
|
|
.worker_threads(4)
|
|
|
|
|
.enable_all()
|
|
|
|
|
.build()?;
|
|
|
|
|
|
|
|
|
|
self.runtime = Some(runtime);
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
fn shutdown(&mut self) -> PyResult<()> {
|
|
|
|
|
let runtime = self
|
|
|
|
|
.runtime
|
|
|
|
|
.take()
|
|
|
|
|
.context("Runtime was already shutdown")?;
|
|
|
|
|
|
|
|
|
|
// Dropping the runtime will shut it down
|
|
|
|
|
drop(runtime);
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl PyTokioRuntime {
|
|
|
|
|
/// Get the handle to the Tokio runtime, if it is running.
|
|
|
|
|
fn handle(&self) -> PyResult<&tokio::runtime::Handle> {
|
|
|
|
|
let handle = self
|
|
|
|
|
.runtime
|
|
|
|
|
.as_ref()
|
|
|
|
|
.context("Tokio runtime is not running")?
|
|
|
|
|
.handle();
|
|
|
|
|
|
|
|
|
|
Ok(handle)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Get a handle to the Tokio runtime stored on the reactor instance, or create
|
|
|
|
|
/// a new one.
|
|
|
|
|
fn runtime<'a>(reactor: &Bound<'a, PyAny>) -> PyResult<PyRef<'a, PyTokioRuntime>> {
|
|
|
|
|
if !reactor.hasattr(TOKIO_RUNTIME_ATTR)? {
|
|
|
|
|
install_runtime(reactor)?;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
get_runtime(reactor)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Install a new Tokio runtime on the reactor instance.
|
|
|
|
|
fn install_runtime(reactor: &Bound<PyAny>) -> PyResult<()> {
|
|
|
|
|
let py = reactor.py();
|
|
|
|
|
let runtime = PyTokioRuntime { runtime: None };
|
|
|
|
|
let runtime = runtime.into_pyobject(py)?;
|
|
|
|
|
|
|
|
|
|
// Attach the runtime to the reactor, starting it when the reactor is
|
|
|
|
|
// running, stopping it when the reactor is shutting down
|
|
|
|
|
reactor.call_method1("callWhenRunning", (runtime.getattr("start")?,))?;
|
|
|
|
|
reactor.call_method1(
|
|
|
|
|
"addSystemEventTrigger",
|
|
|
|
|
("after", "shutdown", runtime.getattr("shutdown")?),
|
|
|
|
|
)?;
|
|
|
|
|
reactor.setattr(TOKIO_RUNTIME_ATTR, runtime)?;
|
|
|
|
|
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Get a reference to a Tokio runtime handle stored on the reactor instance.
|
|
|
|
|
fn get_runtime<'a>(reactor: &Bound<'a, PyAny>) -> PyResult<PyRef<'a, PyTokioRuntime>> {
|
|
|
|
|
// This will raise if `TOKIO_RUNTIME_ATTR` is not set or if it is
|
|
|
|
|
// not a `Runtime`. Careful that this could happen if the user sets it
|
|
|
|
|
// manually, or if multiple versions of `pyo3-twisted` are used!
|
|
|
|
|
let runtime: Bound<PyTokioRuntime> = reactor.getattr(TOKIO_RUNTIME_ATTR)?.extract()?;
|
|
|
|
|
Ok(runtime.borrow())
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// A reference to the `twisted.internet.defer` module.
|
|
|
|
|
static DEFER: OnceCell<PyObject> = OnceCell::new();
|
|
|
|
|
|
|
|
|
|
/// Access to the `twisted.internet.defer` module.
|
|
|
|
|
fn defer(py: Python<'_>) -> PyResult<&Bound<PyAny>> {
|
|
|
|
|
Ok(DEFER
|
|
|
|
|
.get_or_try_init(|| py.import("twisted.internet.defer").map(Into::into))?
|
|
|
|
|
.bind(py))
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Called when registering modules with python.
|
|
|
|
|
pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
|
|
|
|
|
let child_module: Bound<'_, PyModule> = PyModule::new(py, "http_client")?;
|
|
|
|
|
child_module.add_class::<HttpClient>()?;
|
|
|
|
|
|
|
|
|
|
// Make sure we fail early if we can't build the lazy statics.
|
|
|
|
|
LazyLock::force(&RUNTIME);
|
|
|
|
|
LazyLock::force(&DEFERRED_CLASS);
|
|
|
|
|
// Make sure we fail early if we can't load some modules
|
|
|
|
|
defer(py)?;
|
|
|
|
|
|
|
|
|
|
m.add_submodule(&child_module)?;
|
|
|
|
|
|
|
|
|
|
// We need to manually add the module to sys.modules to make `from
|
|
|
|
|
// synapse.synapse_rust import acl` work.
|
|
|
|
|
// synapse.synapse_rust import http_client` work.
|
|
|
|
|
py.import("sys")?
|
|
|
|
|
.getattr("modules")?
|
|
|
|
|
.set_item("synapse.synapse_rust.http_client", child_module)?;
|
|
|
|
|
@@ -72,26 +163,24 @@ pub fn register_module(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()>
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[pyclass]
|
|
|
|
|
#[derive(Clone)]
|
|
|
|
|
struct HttpClient {
|
|
|
|
|
client: reqwest::Client,
|
|
|
|
|
reactor: PyObject,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[pymethods]
|
|
|
|
|
impl HttpClient {
|
|
|
|
|
#[new]
|
|
|
|
|
pub fn py_new(user_agent: &str) -> PyResult<HttpClient> {
|
|
|
|
|
// The twisted reactor can only be imported after Synapse has been
|
|
|
|
|
// imported, to allow Synapse to change the twisted reactor. If we try
|
|
|
|
|
// and import the reactor too early twisted installs a default reactor,
|
|
|
|
|
// which can't be replaced.
|
|
|
|
|
LazyLock::force(&TWISTED_REACTOR);
|
|
|
|
|
pub fn py_new(reactor: Bound<PyAny>, user_agent: &str) -> PyResult<HttpClient> {
|
|
|
|
|
// Make sure the runtime gets installed
|
|
|
|
|
let _ = runtime(&reactor)?;
|
|
|
|
|
|
|
|
|
|
Ok(HttpClient {
|
|
|
|
|
client: reqwest::Client::builder()
|
|
|
|
|
.user_agent(user_agent)
|
|
|
|
|
.build()
|
|
|
|
|
.context("building reqwest client")?,
|
|
|
|
|
reactor: reactor.unbind(),
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -129,7 +218,7 @@ impl HttpClient {
|
|
|
|
|
builder: RequestBuilder,
|
|
|
|
|
response_limit: usize,
|
|
|
|
|
) -> PyResult<Bound<'a, PyAny>> {
|
|
|
|
|
create_deferred(py, async move {
|
|
|
|
|
create_deferred(py, self.reactor.bind(py), async move {
|
|
|
|
|
let response = builder.send().await.context("sending request")?;
|
|
|
|
|
|
|
|
|
|
let status = response.status();
|
|
|
|
|
@@ -159,43 +248,51 @@ impl HttpClient {
|
|
|
|
|
/// tokio runtime.
|
|
|
|
|
///
|
|
|
|
|
/// Does not handle deferred cancellation or contextvars.
|
|
|
|
|
fn create_deferred<F, O>(py: Python, fut: F) -> PyResult<Bound<'_, PyAny>>
|
|
|
|
|
fn create_deferred<'py, F, O>(
|
|
|
|
|
py: Python<'py>,
|
|
|
|
|
reactor: &Bound<'py, PyAny>,
|
|
|
|
|
fut: F,
|
|
|
|
|
) -> PyResult<Bound<'py, PyAny>>
|
|
|
|
|
where
|
|
|
|
|
F: Future<Output = PyResult<O>> + Send + 'static,
|
|
|
|
|
for<'a> O: IntoPyObject<'a>,
|
|
|
|
|
for<'a> O: IntoPyObject<'a> + Send + 'static,
|
|
|
|
|
{
|
|
|
|
|
let deferred = DEFERRED_CLASS.bind(py).call0()?;
|
|
|
|
|
let deferred = defer(py)?.call_method0("Deferred")?;
|
|
|
|
|
let deferred_callback = deferred.getattr("callback")?.unbind();
|
|
|
|
|
let deferred_errback = deferred.getattr("errback")?.unbind();
|
|
|
|
|
|
|
|
|
|
RUNTIME.spawn(async move {
|
|
|
|
|
// TODO: Is it safe to assert unwind safety here? I think so, as we
|
|
|
|
|
// don't use anything that could be tainted by the panic afterwards.
|
|
|
|
|
// Note that `.spawn(..)` asserts unwind safety on the future too.
|
|
|
|
|
let res = AssertUnwindSafe(fut).catch_unwind().await;
|
|
|
|
|
let rt = runtime(reactor)?;
|
|
|
|
|
let handle = rt.handle()?;
|
|
|
|
|
let task = handle.spawn(fut);
|
|
|
|
|
|
|
|
|
|
// Unbind the reactor so that we can pass it to the task
|
|
|
|
|
let reactor = reactor.clone().unbind();
|
|
|
|
|
handle.spawn(async move {
|
|
|
|
|
let res = task.await;
|
|
|
|
|
|
|
|
|
|
Python::with_gil(move |py| {
|
|
|
|
|
// Flatten the panic into standard python error
|
|
|
|
|
let res = match res {
|
|
|
|
|
Ok(r) => r,
|
|
|
|
|
Err(panic_err) => {
|
|
|
|
|
let panic_message = get_panic_message(&panic_err);
|
|
|
|
|
Err(PyException::new_err(
|
|
|
|
|
PyString::new(py, panic_message).unbind(),
|
|
|
|
|
))
|
|
|
|
|
}
|
|
|
|
|
Err(join_err) => match join_err.try_into_panic() {
|
|
|
|
|
Ok(panic_err) => Err(RustPanicError::from_panic(&panic_err)),
|
|
|
|
|
Err(err) => Err(PyException::new_err(format!("Task cancelled: {err}"))),
|
|
|
|
|
},
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
// Re-bind the reactor
|
|
|
|
|
let reactor = reactor.bind(py);
|
|
|
|
|
|
|
|
|
|
// Send the result to the deferred, via `.callback(..)` or `.errback(..)`
|
|
|
|
|
match res {
|
|
|
|
|
Ok(obj) => {
|
|
|
|
|
TWISTED_REACTOR
|
|
|
|
|
.call_method(py, "callFromThread", (deferred_callback, obj), None)
|
|
|
|
|
reactor
|
|
|
|
|
.call_method("callFromThread", (deferred_callback, obj), None)
|
|
|
|
|
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
|
|
|
|
|
}
|
|
|
|
|
Err(err) => {
|
|
|
|
|
TWISTED_REACTOR
|
|
|
|
|
.call_method(py, "callFromThread", (deferred_errback, err), None)
|
|
|
|
|
reactor
|
|
|
|
|
.call_method("callFromThread", (deferred_errback, err), None)
|
|
|
|
|
.expect("callFromThread should not fail"); // There's nothing we can really do with errors here
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
@@ -204,15 +301,3 @@ where
|
|
|
|
|
|
|
|
|
|
Ok(deferred)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/// Try and get the panic message out of the panic
|
|
|
|
|
fn get_panic_message<'a>(panic_err: &'a (dyn std::any::Any + Send + 'static)) -> &'a str {
|
|
|
|
|
// Apparently this is how you extract the panic message from a panic
|
|
|
|
|
if let Some(str_slice) = panic_err.downcast_ref::<&str>() {
|
|
|
|
|
str_slice
|
|
|
|
|
} else if let Some(string) = panic_err.downcast_ref::<String>() {
|
|
|
|
|
string
|
|
|
|
|
} else {
|
|
|
|
|
"unknown error"
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|