From fcaf248dfc8afb33a806b1a51662ad0f4ae6ae0c Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Wed, 17 Apr 2024 15:28:20 -0700 Subject: [PATCH 1/2] feat(server-common)!: remove Clone for Config --- server-common/src/config.rs | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/server-common/src/config.rs b/server-common/src/config.rs index 19f2c5f655..d90f3393d9 100644 --- a/server-common/src/config.rs +++ b/server-common/src/config.rs @@ -245,32 +245,6 @@ impl Config { } } -impl Clone for Config -where - ServerType: Server, - AcceptorType: Acceptor + Clone, -{ - fn clone(&self) -> Self { - if self.has_binding() { - eprintln!("cloning a Config with a pre-bound listener will not clone the listener. this may be a panic in the future."); - } - - Self { - acceptor: self.acceptor.clone(), - port: self.port, - host: self.host.clone(), - server: PhantomData, - nodelay: self.nodelay, - swansong: self.swansong.clone(), - register_signals: self.register_signals, - max_connections: self.max_connections, - info: AsyncCell::shared(), - binding: RwLock::new(None), - http_config: self.http_config, - } - } -} - impl Default for Config { fn default() -> Self { #[cfg(unix)] From 257b7d14eb664c78f1be0435c4b7fc6ecc6e32b5 Mon Sep 17 00:00:00 2001 From: Jacob Rothstein Date: Wed, 17 Apr 2024 19:23:51 -0700 Subject: [PATCH 2/2] feat!: introduce Runtime --- async-std/Cargo.toml | 3 +- async-std/src/client.rs | 16 +- async-std/src/lib.rs | 8 +- async-std/src/runtime.rs | 86 ++++++++++ async-std/src/server.rs | 1 - async-std/src/server/tcp.rs | 16 +- async-std/src/server/unix.rs | 15 +- client/src/conn.rs | 62 ++++--- client/tests/base.rs | 35 +++- client/tests/one_hundred_continue.rs | 20 +-- client/tests/timeout.rs | 11 +- client/tests/websocket.rs | 6 +- http/tests/corpus.rs | 5 +- http/tests/one_hundred_continue.rs | 21 +-- http/tests/unsafe_headers.rs | 11 +- http/tests/use_cases.rs | 17 +- native-tls/src/client.rs | 10 +- rustls/src/client.rs | 10 +- server-common/Cargo.toml | 1 + server-common/src/client.rs | 62 ++++--- server-common/src/config.rs | 17 +- server-common/src/config_ext.rs | 23 +-- server-common/src/lib.rs | 5 +- server-common/src/runtime.rs | 120 ++++++++++++++ server-common/src/runtime/droppable_future.rs | 41 +++++ .../src/runtime/object_safe_runtime.rs | 57 +++++++ server-common/src/runtime/runtime_trait.rs | 54 ++++++ server-common/src/server.rs | 30 ++-- server-common/src/server_handle.rs | 24 ++- smol/Cargo.toml | 6 +- smol/examples/smol.rs | 29 ++-- smol/src/client.rs | 27 ++- smol/src/lib.rs | 10 +- smol/src/runtime.rs | 110 ++++++++++++ smol/src/server.rs | 4 +- smol/src/server/tcp.rs | 33 ++-- smol/src/server/unix.rs | 27 +-- smol/src/transport.rs | 7 + smol/tests/unix_stream.rs | 21 +++ testing/Cargo.toml | 2 +- testing/src/lib.rs | 148 ++++------------- testing/src/runtimeless.rs | 156 +----------------- testing/src/runtimeless/client.rs | 36 ++++ testing/src/runtimeless/runtime.rs | 112 +++++++++++++ testing/src/runtimeless/server.rs | 84 ++++++++++ testing/src/server_connector.rs | 97 ++--------- testing/src/with_server.rs | 20 ++- testing/tests/runtimeless.rs | 36 ++++ testing/tests/server_connector.rs | 54 ++++++ testing/tests/spawn.rs | 35 ++-- tokio/Cargo.toml | 3 +- tokio/src/client.rs | 12 +- tokio/src/lib.rs | 15 +- tokio/src/runtime.rs | 117 +++++++++++++ tokio/src/server.rs | 1 - tokio/src/server/tcp.rs | 22 +-- tokio/src/server/unix.rs | 23 +-- tokio/tests/tests.rs | 42 +++-- trillium/tests/liveness.rs | 10 +- 59 files changed, 1390 insertions(+), 696 deletions(-) create mode 100644 async-std/src/runtime.rs create mode 100644 server-common/src/runtime.rs create mode 100644 server-common/src/runtime/droppable_future.rs create mode 100644 server-common/src/runtime/object_safe_runtime.rs create mode 100644 server-common/src/runtime/runtime_trait.rs create mode 100644 smol/src/runtime.rs create mode 100644 smol/tests/unix_stream.rs create mode 100644 testing/src/runtimeless/client.rs create mode 100644 testing/src/runtimeless/runtime.rs create mode 100644 testing/src/runtimeless/server.rs create mode 100644 testing/tests/runtimeless.rs create mode 100644 testing/tests/server_connector.rs create mode 100644 tokio/src/runtime.rs diff --git a/async-std/Cargo.toml b/async-std/Cargo.toml index 22696c9918..657ad34e68 100644 --- a/async-std/Cargo.toml +++ b/async-std/Cargo.toml @@ -11,7 +11,8 @@ keywords = ["trillium", "framework", "async"] categories = ["web-programming::http-server", "web-programming"] [dependencies] -async-std = "1.12.0" +async-std = { version = "1.12.0", features = ["unstable"] } +futures-lite = "2.3.0" log = "0.4.20" trillium = { path = "../trillium", version = "0.2.19" } trillium-http = { path = "../http", version = "0.3.16" } diff --git a/async-std/src/client.rs b/async-std/src/client.rs index 5d6808661c..5db6e4e50c 100644 --- a/async-std/src/client.rs +++ b/async-std/src/client.rs @@ -1,9 +1,6 @@ -use crate::AsyncStdTransport; +use crate::{AsyncStdRuntime, AsyncStdTransport}; use async_std::net::TcpStream; -use std::{ - future::Future, - io::{Error, ErrorKind, Result}, -}; +use std::io::{Error, ErrorKind, Result}; use trillium_server_common::{ url::{Host, Url}, Connector, Transport, @@ -45,6 +42,7 @@ impl ClientConfig { impl Connector for ClientConfig { type Transport = AsyncStdTransport; + type Runtime = AsyncStdRuntime; async fn connect(&self, url: &Url) -> Result { if url.scheme() != "http" { @@ -80,11 +78,7 @@ impl Connector for ClientConfig { Ok(tcp) } - fn spawn + Send + 'static>(&self, fut: Fut) { - async_std::task::spawn(fut); - } - - async fn delay(&self, duration: std::time::Duration) { - let _ = async_std::future::timeout(duration, std::future::pending::<()>()).await; + fn runtime(&self) -> Self::Runtime { + AsyncStdRuntime::default() } } diff --git a/async-std/src/lib.rs b/async-std/src/lib.rs index c07dff746f..8c3c5a36e4 100644 --- a/async-std/src/lib.rs +++ b/async-std/src/lib.rs @@ -32,8 +32,6 @@ async fn main() { ``` */ -use std::future::Future; - use trillium::Handler; pub use trillium_server_common::{Binding, Swansong}; @@ -113,7 +111,5 @@ pub fn config() -> Config<()> { Config::new() } -/// spawn and detach a Future that returns () -pub fn spawn + Send + 'static>(future: Fut) { - async_std::task::spawn(future); -} +mod runtime; +pub use runtime::AsyncStdRuntime; diff --git a/async-std/src/runtime.rs b/async-std/src/runtime.rs new file mode 100644 index 0000000000..a4f1f19207 --- /dev/null +++ b/async-std/src/runtime.rs @@ -0,0 +1,86 @@ +use futures_lite::future::FutureExt; +use std::{future::Future, time::Duration}; +use trillium_server_common::{DroppableFuture, Runtime, RuntimeTrait, Stream}; + +/// async-std runtime +#[derive(Clone, Copy, Default, Debug)] +pub struct AsyncStdRuntime(()); + +impl RuntimeTrait for AsyncStdRuntime { + fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let join_handle = async_std::task::spawn(fut); + DroppableFuture::new(async move { join_handle.catch_unwind().await.ok() }) + } + + async fn delay(&self, duration: Duration) { + async_std::task::sleep(duration).await + } + + fn interval(&self, period: Duration) -> impl Stream + Send + 'static { + async_std::stream::interval(period) + } + + fn block_on(&self, fut: Fut) -> Fut::Output { + async_std::task::block_on(fut) + } +} + +impl AsyncStdRuntime { + /// Spawn a future on the runtime, returning a future that has detach-on-drop semantics + /// + /// Spawned tasks conform to the following behavior: + /// + /// * detach on drop: If the returned [`DroppableFuture`] is dropped immediately, the task will + /// continue to execute until completion. + /// + /// * unwinding: If the spawned future panics, this must not propagate to the join + /// handle. Instead, the awaiting the join handle returns None in case of panic. + pub fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let join_handle = async_std::task::spawn(fut); + DroppableFuture::new(async move { join_handle.catch_unwind().await.ok() }) + } + + /// Wake in this amount of wall time + pub async fn delay(&self, duration: Duration) { + async_std::task::sleep(duration).await + } + + /// Returns a [`Stream`] that yields a `()` on the provided period + pub fn interval(&self, period: Duration) -> impl Stream + Send + 'static { + async_std::stream::interval(period) + } + + /// Runtime implementation hook for blocking on a top level future. + pub fn block_on(&self, fut: Fut) -> Fut::Output { + async_std::task::block_on(fut) + } + + /// Race a future against the provided duration, returning None in case of timeout. + pub async fn timeout(&self, duration: Duration, fut: Fut) -> Option + where + Fut: Future + Send, + Fut::Output: Send + 'static, + { + RuntimeTrait::timeout(self, duration, fut).await + } +} + +impl From for Runtime { + fn from(value: AsyncStdRuntime) -> Self { + Runtime::new(value) + } +} diff --git a/async-std/src/server.rs b/async-std/src/server.rs index 5771c659ce..b7323e64c4 100644 --- a/async-std/src/server.rs +++ b/async-std/src/server.rs @@ -3,7 +3,6 @@ mod unix; #[cfg(unix)] pub use unix::AsyncStdServer; -#[cfg(not(unix))] mod tcp; #[cfg(not(unix))] pub use tcp::AsyncStdServer; diff --git a/async-std/src/server/tcp.rs b/async-std/src/server/tcp.rs index a45d3dfeee..aff2e294d5 100644 --- a/async-std/src/server/tcp.rs +++ b/async-std/src/server/tcp.rs @@ -1,7 +1,6 @@ -use crate::AsyncStdTransport; +use crate::{AsyncStdRuntime, AsyncStdTransport}; use async_std::net::{TcpListener, TcpStream}; -use async_std::task::{block_on, spawn}; -use std::{convert::TryInto, env, future::Future, io::Result}; +use std::{env, io::Result}; use trillium::Info; use trillium_server_common::Server; @@ -20,6 +19,7 @@ impl From for AsyncStdServer { } impl Server for AsyncStdServer { + type Runtime = AsyncStdRuntime; type Transport = AsyncStdTransport; const DESCRIPTION: &'static str = concat!( " (", @@ -34,18 +34,14 @@ impl Server for AsyncStdServer { } fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { - Self(tcp.try_into().unwrap()) + Self(tcp.into()) } fn info(&self) -> Info { self.0.local_addr().unwrap().into() } - fn spawn(fut: impl Future + Send + 'static) { - spawn(fut); - } - - fn block_on(fut: impl Future + 'static) { - block_on(fut) + fn runtime() -> Self::Runtime { + AsyncStdRuntime::default() } } diff --git a/async-std/src/server/unix.rs b/async-std/src/server/unix.rs index 86c36ed668..0fad8fab9d 100644 --- a/async-std/src/server/unix.rs +++ b/async-std/src/server/unix.rs @@ -1,11 +1,10 @@ -use crate::AsyncStdTransport; +use crate::{AsyncStdRuntime, AsyncStdTransport}; use async_std::{ net::{TcpListener, TcpStream}, os::unix::net::{UnixListener, UnixStream}, stream::StreamExt, - task::{block_on, spawn}, }; -use std::{env, future::Future, io::Result}; +use std::{env, io::Result}; use trillium::{log_error, Info}; use trillium_server_common::{ Binding::{self, *}, @@ -39,6 +38,8 @@ impl From for AsyncStdServer { #[cfg(unix)] impl Server for AsyncStdServer { + type Runtime = AsyncStdRuntime; + type Transport = Binding, AsyncStdTransport>; const DESCRIPTION: &'static str = concat!( " (", @@ -94,12 +95,8 @@ impl Server for AsyncStdServer { } } - fn spawn(fut: impl Future + Send + 'static) { - spawn(fut); - } - - fn block_on(fut: impl Future + 'static) { - block_on(fut); + fn runtime() -> Self::Runtime { + AsyncStdRuntime::default() } async fn clean_up(self) { diff --git a/client/src/conn.rs b/client/src/conn.rs index 26a4a401fb..99fea0d60a 100644 --- a/client/src/conn.rs +++ b/client/src/conn.rs @@ -1,6 +1,6 @@ use crate::{pool::PoolEntry, util::encoding, Pool}; use encoding_rs::Encoding; -use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt, FutureExt}; +use futures_lite::{future::poll_once, io, AsyncReadExt, AsyncWriteExt}; use memchr::memmem::Finder; use size::{Base, Size}; use std::{ @@ -94,7 +94,7 @@ impl Conn { chainable setter for [`inserting`](Headers::insert) a request header ``` - use trillium_testing::ClientConfig; + use trillium_testing::client_config; let handler = |conn: trillium::Conn| async move { @@ -103,7 +103,7 @@ impl Conn { conn.ok(response) }; - let client = trillium_client::Client::new(ClientConfig::new()); + let client = trillium_client::Client::new(client_config()); trillium_testing::with_server(handler, |url| async move { let mut conn = client.get(url) @@ -137,8 +137,8 @@ impl Conn { conn.ok(response) }; - use trillium_testing::ClientConfig; - let client = trillium_client::client(ClientConfig::new()); + use trillium_testing::client_config; + let client = trillium_client::client(client_config()); trillium_testing::with_server(handler, move |url| async move { let mut conn = client.get(url) @@ -179,10 +179,10 @@ impl Conn { }; use trillium_client::Client; - use trillium_testing::ClientConfig; + use trillium_testing::client_config; trillium_testing::with_server(handler, move |url| async move { - let client = Client::new(ClientConfig::new()); + let client = Client::new(client_config()); let conn = client.get(url).await?; let headers = conn.response_headers(); //<- @@ -202,7 +202,7 @@ impl Conn { Conn ``` - use trillium_testing::ClientConfig; + use trillium_testing::client_config; use trillium_client::Client; let handler = |conn: trillium::Conn| async move { @@ -211,7 +211,7 @@ impl Conn { conn.ok(response) }; - let client = Client::new(ClientConfig::new()); + let client = Client::new(client_config()); trillium_testing::with_server(handler, move |url| async move { let mut conn = client.get(url); @@ -245,7 +245,7 @@ impl Conn { ``` env_logger::init(); use trillium_client::Client; - use trillium_testing::ClientConfig; + use trillium_testing::client_config; let handler = |mut conn: trillium::Conn| async move { @@ -254,7 +254,7 @@ impl Conn { }; trillium_testing::with_server(handler, move |url| async move { - let client = Client::new(ClientConfig::new()); + let client = Client::new(client_config()); let mut conn = client.post(url); conn.set_request_body("body"); //<- @@ -275,7 +275,7 @@ impl Conn { ``` env_logger::init(); - use trillium_testing::ClientConfig; + use trillium_testing::client_config; use trillium_client::Client; let handler = |mut conn: trillium::Conn| async move { @@ -285,7 +285,7 @@ impl Conn { trillium_testing::with_server(handler, |url| async move { - let client = Client::from(ClientConfig::default()); + let client = Client::from(client_config()); let mut conn = client.post(url) .with_body("body") //<- .await?; @@ -322,9 +322,9 @@ impl Conn { /** retrieves the url for this conn. ``` - use trillium_testing::ClientConfig; + use trillium_testing::client_config; use trillium_client::Client; - let client = Client::from(ClientConfig::new()); + let client = Client::from(client_config()); let conn = client.get("http://localhost:9080"); let url = conn.url(); //<- @@ -339,12 +339,12 @@ impl Conn { /** retrieves the url for this conn. ``` - use trillium_testing::ClientConfig; + use trillium_testing::client_config; use trillium_client::Client; use trillium_testing::prelude::*; - let client = Client::from(ClientConfig::new()); + let client = Client::from(client_config()); let conn = client.get("http://localhost:9080"); let method = conn.method(); //<- @@ -360,7 +360,7 @@ impl Conn { returns a [`ReceivedBody`] that borrows the connection inside this conn. ``` env_logger::init(); - use trillium_testing::ClientConfig; + use trillium_testing::client_config; use trillium_client::Client; @@ -370,7 +370,7 @@ impl Conn { }; trillium_testing::with_server(handler, |url| async move { - let client = Client::from(ClientConfig::new()); + let client = Client::from(client_config()); let mut conn = client.get(url).await?; let response_body = conn.response_body(); //<- @@ -425,7 +425,7 @@ impl Conn { been sent, this will be None. ``` - use trillium_testing::ClientConfig; + use trillium_testing::client_config; use trillium_client::Client; use trillium_testing::prelude::*; @@ -434,7 +434,7 @@ impl Conn { } trillium_testing::with_server(handler, |url| async move { - let client = Client::new(ClientConfig::new()); + let client = Client::new(client_config()); let conn = client.get(url).await?; assert_eq!(Status::ImATeapot, conn.status().unwrap()); Ok(()) @@ -449,10 +449,10 @@ impl Conn { Returns the conn or an [`UnexpectedStatusError`] that contains the conn ``` - use trillium_testing::ClientConfig; + use trillium_testing::client_config; trillium_testing::with_server(trillium::Status::NotFound, |url| async move { - let client = trillium_client::Client::new(ClientConfig::new()); + let client = trillium_client::Client::new(client_config()); assert_eq!( client.get(url).await?.success().unwrap_err().to_string(), "expected a success (2xx) status code, but got 404 Not Found" @@ -461,7 +461,7 @@ impl Conn { }); trillium_testing::with_server(trillium::Status::Ok, |url| async move { - let client = trillium_client::Client::new(ClientConfig::new()); + let client = trillium_client::Client::new(client_config()); assert!(client.get(url).await?.success().is_ok()); Ok(()) }); @@ -870,7 +870,7 @@ impl Drop for Conn { let buffer = std::mem::take(&mut self.buffer); let response_body_state = self.response_body_state; let encoding = encoding(&self.response_headers); - Connector::spawn(&self.config, async move { + self.config.runtime().spawn(async move { let mut response_body = ReceivedBody::new( content_length, buffer, @@ -950,13 +950,11 @@ impl IntoFuture for Conn { fn into_future(mut self) -> Self::IntoFuture { Box::pin(async move { if let Some(duration) = self.timeout { - let config = self.config.clone(); - self.exec() - .or(async { - config.delay(duration).await; - Err(Error::TimedOut("Conn", duration)) - }) - .await? + self.config + .runtime() + .timeout(duration, self.exec()) + .await + .ok_or(Error::TimedOut("Conn", duration))??; } else { self.exec().await?; } diff --git a/client/tests/base.rs b/client/tests/base.rs index 0373a454f2..35f40d5bb8 100644 --- a/client/tests/base.rs +++ b/client/tests/base.rs @@ -1,4 +1,7 @@ -use std::str::FromStr; +use std::{ + net::{IpAddr, SocketAddr}, + str::FromStr, +}; use test_harness::test; use trillium_client::{Client, Status}; use trillium_testing::{harness, ServerConnector, TestResult, Url}; @@ -83,6 +86,36 @@ async fn without_base() -> TestResult { .build_url(Url::from_str("data:text/plain,Stuff")?) .is_err()); + assert_eq!( + client + .build_url(IpAddr::from_str("127.0.0.1").unwrap()) + .unwrap() + .as_str(), + "http://127.0.0.1/" + ); + assert_eq!( + client + .build_url(IpAddr::from_str("::1").unwrap()) + .unwrap() + .as_str(), + "http://[::1]/" + ); + + assert_eq!( + client + .build_url(SocketAddr::from_str("127.0.0.1:8080").unwrap()) + .unwrap() + .as_str(), + "http://127.0.0.1:8080/" + ); + assert_eq!( + client + .build_url(SocketAddr::from_str("[::1]:8080").unwrap()) + .unwrap() + .as_str(), + "http://[::1]:8080/" + ); + Ok(()) } diff --git a/client/tests/one_hundred_continue.rs b/client/tests/one_hundred_continue.rs index 4337211fb5..75acd609c1 100644 --- a/client/tests/one_hundred_continue.rs +++ b/client/tests/one_hundred_continue.rs @@ -6,7 +6,7 @@ use std::future::{Future, IntoFuture}; use test_harness::test; use trillium_client::{Client, Conn, Error, Status, USER_AGENT}; use trillium_server_common::{Connector, Url}; -use trillium_testing::{harness, TestResult, TestTransport}; +use trillium_testing::{harness, RuntimeTrait, TestResult, TestTransport}; #[test(harness)] async fn extra_one_hundred_continue() -> TestResult { @@ -216,22 +216,19 @@ async fn little_continue_big_continue() -> TestResult { const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; -struct TestConnector(Sender); +struct TestConnector(Sender, R); -impl Connector for TestConnector { +impl Connector for TestConnector { type Transport = TestTransport; + type Runtime = R; async fn connect(&self, _url: &Url) -> std::io::Result { let (server, client) = TestTransport::new(); let _ = self.0.send(server).await; Ok(client) } - fn spawn + Send + 'static>(&self, fut: Fut) { - let _ = trillium_testing::spawn(fut); - } - - async fn delay(&self, duration: std::time::Duration) { - trillium_testing::delay(duration).await + fn runtime(&self) -> Self::Runtime { + self.1.clone() } } @@ -239,8 +236,9 @@ async fn test_conn( setup: impl FnOnce(Client) -> Conn + Send + 'static, ) -> (TestTransport, impl Future>) { let (sender, receiver) = async_channel::unbounded(); - let client = Client::new(TestConnector(sender)); - let conn_fut = trillium_testing::spawn(setup(client).into_future()); + let client = Client::new(TestConnector(sender, trillium_testing::runtime())); + let runtime = client.connector().runtime(); + let conn_fut = runtime.spawn(setup(client).into_future()).into_future(); let transport = receiver.recv().await.unwrap(); (transport, async move { conn_fut.await.unwrap() }) } diff --git a/client/tests/timeout.rs b/client/tests/timeout.rs index 9a5c24cd8f..a46ce108ce 100644 --- a/client/tests/timeout.rs +++ b/client/tests/timeout.rs @@ -1,18 +1,19 @@ use std::time::Duration; use trillium_client::Client; -use trillium_testing::ClientConfig; +use trillium_testing::{client_config, runtime, RuntimeTrait}; async fn handler(conn: trillium::Conn) -> trillium::Conn { if conn.path() == "/slow" { - trillium_testing::delay(Duration::from_secs(5)).await; + runtime().delay(Duration::from_secs(5)).await; } conn.ok("ok") } #[test] fn timeout_on_conn() { + let _ = env_logger::builder().is_test(true).try_init(); trillium_testing::with_server(handler, move |url| async move { - let client = Client::new(ClientConfig::new()).with_base(url); + let client = Client::new(client_config()).with_base(url); let err = client .get("/slow") .with_timeout(Duration::from_millis(100)) @@ -33,8 +34,10 @@ fn timeout_on_conn() { #[test] fn timeout_on_client() { + let _ = env_logger::builder().is_test(true).try_init(); + trillium_testing::with_server(handler, move |url| async move { - let client = Client::new(ClientConfig::new()) + let client = Client::new(client_config()) .with_base(url) .with_timeout(Duration::from_millis(100)); let err = client.get("/slow").await.unwrap_err(); diff --git a/client/tests/websocket.rs b/client/tests/websocket.rs index 0052563ada..42ec9f935f 100644 --- a/client/tests/websocket.rs +++ b/client/tests/websocket.rs @@ -4,7 +4,7 @@ use trillium_client::{ Client, WebSocketConn, }; use trillium_http::Status; -use trillium_testing::ClientConfig; +use trillium_testing::client_config; use trillium_websockets::websocket; #[test] @@ -17,7 +17,7 @@ fn test_websockets() { } }); - let client = Client::new(ClientConfig::new()); + let client = Client::new(client_config()); trillium_testing::with_server(handler, move |url| async move { let mut ws = client.get(url).into_websocket().await?; @@ -39,7 +39,7 @@ fn test_websockets() { fn test_websockets_error() { let handler = |conn: trillium::Conn| async { conn.with_status(404).with_body("This does not exist") }; - let client = Client::new(ClientConfig::new()); + let client = Client::new(client_config()); trillium_testing::with_server(handler, move |url| async move { let err = client .get(url) diff --git a/http/tests/corpus.rs b/http/tests/corpus.rs index c410f752ab..9e17be2c70 100644 --- a/http/tests/corpus.rs +++ b/http/tests/corpus.rs @@ -2,7 +2,7 @@ use indoc::formatdoc; use pretty_assertions::assert_eq; use test_harness::test; use trillium_http::{Conn, KnownHeaderName, Swansong}; -use trillium_testing::{harness, TestTransport}; +use trillium_testing::{harness, RuntimeTrait, TestTransport}; const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; async fn handler(mut conn: Conn) -> Conn { @@ -44,6 +44,7 @@ async fn handler(mut conn: Conn) -> Conn { #[test(harness)] async fn corpus_test() { env_logger::init(); + let runtime = trillium_testing::runtime(); let dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/corpus"); let filter = std::env::var("CORPUS_TEST_FILTER").unwrap_or_default(); let corpus_request_files = std::fs::read_dir(dir) @@ -67,7 +68,7 @@ async fn corpus_test() { let (client, server) = TestTransport::new(); let swansong = Swansong::new(); - let res = trillium_testing::spawn({ + let res = runtime.spawn({ let swansong = swansong.clone(); async move { Conn::map(server, swansong, handler).await } }); diff --git a/http/tests/one_hundred_continue.rs b/http/tests/one_hundred_continue.rs index b2c6ac4ccd..87f8bb869f 100644 --- a/http/tests/one_hundred_continue.rs +++ b/http/tests/one_hundred_continue.rs @@ -1,9 +1,8 @@ use indoc::{formatdoc, indoc}; use pretty_assertions::assert_eq; -use swansong::Swansong; use test_harness::test; -use trillium_http::{Conn, KnownHeaderName, SERVER}; -use trillium_testing::{harness, TestResult, TestTransport}; +use trillium_http::{Conn, KnownHeaderName, Swansong, SERVER}; +use trillium_testing::{harness, RuntimeTrait, TestResult, TestTransport}; const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; @@ -21,10 +20,8 @@ async fn handler(mut conn: Conn) -> Conn { #[test(harness)] async fn one_hundred_continue() -> TestResult { let (client, server) = TestTransport::new(); - - trillium_testing::spawn(async move { - Conn::map(server, Swansong::new(), handler).await.unwrap(); - }); + let runtime = trillium_testing::runtime(); + let handle = runtime.spawn(async move { Conn::map(server, Swansong::new(), handler).await }); client.write_all(indoc! {" POST / HTTP/1.1\r @@ -52,17 +49,15 @@ async fn one_hundred_continue() -> TestResult { "}; assert_eq!(client.read_available_string().await, expected_response); - + handle.await.unwrap().unwrap(); Ok(()) } #[test(harness)] async fn one_hundred_continue_http_one_dot_zero() -> TestResult { let (client, server) = TestTransport::new(); - - trillium_testing::spawn(async move { - Conn::map(server, Swansong::new(), handler).await.unwrap(); - }); + let runtime = trillium_testing::runtime(); + let handle = runtime.spawn(async move { Conn::map(server, Swansong::new(), handler).await }); client.write_all(indoc! { " POST / HTTP/1.0\r @@ -85,6 +80,6 @@ async fn one_hundred_continue_http_one_dot_zero() -> TestResult { "}; assert_eq!(client.read_available_string().await, expected_response); - + handle.await.unwrap().unwrap(); Ok(()) } diff --git a/http/tests/unsafe_headers.rs b/http/tests/unsafe_headers.rs index 1b7029c767..b4c3187e5f 100644 --- a/http/tests/unsafe_headers.rs +++ b/http/tests/unsafe_headers.rs @@ -3,7 +3,7 @@ use pretty_assertions::assert_eq; use swansong::Swansong; use test_harness::test; use trillium_http::{Conn, KnownHeaderName, SERVER}; -use trillium_testing::{harness, TestResult, TestTransport}; +use trillium_testing::{harness, RuntimeTrait, TestResult, TestTransport}; const TEST_DATE: &str = "Tue, 21 Nov 2023 21:27:21 GMT"; @@ -23,14 +23,13 @@ async fn handler(mut conn: Conn) -> Conn { #[test(harness)] async fn bad_headers() -> TestResult { let (client, server) = TestTransport::new(); - - trillium_testing::spawn(async move { - Conn::map(server, Swansong::new(), handler).await.unwrap(); - }); + let runtime = trillium_testing::runtime(); + let handle = runtime.spawn(async move { Conn::map(server, Swansong::new(), handler).await }); client.write_all(indoc! {" GET / HTTP/1.1\r Host: example.com\r + Connection: close\r \r "}); @@ -45,5 +44,7 @@ async fn bad_headers() -> TestResult { assert_eq!(client.read_available_string().await, expected_response); + handle.await.unwrap().unwrap(); + Ok(()) } diff --git a/http/tests/use_cases.rs b/http/tests/use_cases.rs index a6a23a059d..8ef194ed3d 100644 --- a/http/tests/use_cases.rs +++ b/http/tests/use_cases.rs @@ -4,9 +4,9 @@ use std::{future::Future, marker::PhantomData, sync::Arc}; use test_harness::test; use trillium_client::{Client, Connector, Url}; use trillium_http::{Conn, KnownHeaderName}; -use trillium_testing::{TestResult, TestTransport}; +use trillium_testing::{harness, Runtime, TestResult, TestTransport}; -#[test(harness = trillium_testing::harness)] +#[test(harness)] async fn send_no_server_header() -> TestResult { let client = Client::new(ServerConnector::new(|mut conn| async move { conn.response_headers_mut().remove(KnownHeaderName::Server); @@ -21,6 +21,7 @@ async fn send_no_server_header() -> TestResult { pub struct ServerConnector { handler: Arc, fut: PhantomData, + runtime: Runtime, } impl ServerConnector @@ -32,6 +33,7 @@ where Self { handler: Arc::new(handler), fut: PhantomData, + runtime: trillium_testing::runtime().into(), } } } @@ -42,13 +44,14 @@ where Fut: Future> + Send + Sync + 'static, { type Transport = TestTransport; + type Runtime = Runtime; async fn connect(&self, _: &Url) -> std::io::Result { let (client_transport, server_transport) = TestTransport::new(); let handler = self.handler.clone(); - trillium_testing::spawn(async move { + self.runtime.spawn(async move { Conn::map(server_transport, Default::default(), &*handler) .await .unwrap(); @@ -57,11 +60,7 @@ where Ok(client_transport) } - fn spawn + Send + 'static>(&self, fut: SpawnFut) { - trillium_testing::spawn(fut); - } - - async fn delay(&self, duration: std::time::Duration) { - trillium_testing::delay(duration).await + fn runtime(&self) -> Self::Runtime { + self.runtime.clone() } } diff --git a/native-tls/src/client.rs b/native-tls/src/client.rs index b5094e3f60..66ecfb03dc 100644 --- a/native-tls/src/client.rs +++ b/native-tls/src/client.rs @@ -1,7 +1,6 @@ use async_native_tls::{TlsConnector, TlsStream}; use std::{ fmt::{Debug, Formatter}, - future::Future, io::{Error, ErrorKind, IoSlice, IoSliceMut, Result}, net::SocketAddr, pin::Pin, @@ -69,6 +68,7 @@ impl AsRef for NativeTlsConfig { } impl Connector for NativeTlsConfig { + type Runtime = T::Runtime; type Transport = NativeTlsClientTransport; async fn connect(&self, url: &Url) -> Result { @@ -99,12 +99,8 @@ impl Connector for NativeTlsConfig { } } - fn spawn + Send + 'static>(&self, fut: Fut) { - self.tcp_config.spawn(fut) - } - - async fn delay(&self, duration: std::time::Duration) { - self.tcp_config.delay(duration).await + fn runtime(&self) -> Self::Runtime { + self.tcp_config.runtime() } } diff --git a/rustls/src/client.rs b/rustls/src/client.rs index 271e9ff170..788d038ad5 100644 --- a/rustls/src/client.rs +++ b/rustls/src/client.rs @@ -9,7 +9,6 @@ use futures_rustls::{ }; use std::{ fmt::{self, Debug, Formatter}, - future::Future, io::{Error, ErrorKind, IoSlice, Result}, net::SocketAddr, pin::Pin, @@ -108,6 +107,7 @@ impl Debug for RustlsConfig { impl Connector for RustlsConfig { type Transport = RustlsClientTransport; + type Runtime = C::Runtime; async fn connect(&self, url: &Url) -> Result { match url.scheme() { @@ -138,12 +138,8 @@ impl Connector for RustlsConfig { } } - fn spawn + Send + 'static>(&self, fut: Fut) { - self.tcp_config.spawn(fut) - } - - async fn delay(&self, duration: std::time::Duration) { - self.tcp_config.delay(duration).await + fn runtime(&self) -> Self::Runtime { + self.tcp_config.runtime() } } diff --git a/server-common/Cargo.toml b/server-common/Cargo.toml index 3af622d3f5..f01a5c17ae 100644 --- a/server-common/Cargo.toml +++ b/server-common/Cargo.toml @@ -11,6 +11,7 @@ keywords = ["trillium", "framework", "async"] categories = ["web-programming::http-server", "web-programming"] [dependencies] +async-channel = "2.2.0" async_cell = "0.2.2" futures-lite = "2.1.0" log = "0.4.20" diff --git a/server-common/src/client.rs b/server-common/src/client.rs index af5ca864df..5a44b8377c 100644 --- a/server-common/src/client.rs +++ b/server-common/src/client.rs @@ -1,6 +1,6 @@ use trillium_http::transport::BoxedTransport; -use crate::{Transport, Url}; +use crate::{Runtime, RuntimeTrait, Transport, Url}; use std::{ any::Any, fmt::{self, Debug}, @@ -8,7 +8,6 @@ use std::{ io, pin::Pin, sync::Arc, - time::Duration, }; /** Interface for runtime and tls adapters for the trillium client @@ -21,14 +20,22 @@ pub trait Connector: Send + Sync + 'static { /// the [`Transport`] that [`connect`] returns type Transport: Transport; + /// The [`RuntimeTrait`] for this Connector + type Runtime: RuntimeTrait; + /// Initiate a connection to the provided url fn connect(&self, url: &Url) -> impl Future> + Send; - /// spawn and detach a future on the runtime - fn spawn + Send + 'static>(&self, fut: Fut); + /// Returns an object-safe [`ArcedConnector`]. Do not implement this. + fn arced(self) -> ArcedConnector + where + Self: Sized, + { + ArcedConnector(Arc::new(self)) + } - /// wake in this amount of wall time - fn delay(&self, duration: Duration) -> impl Future + Send; + /// Returns the runtime + fn runtime(&self) -> Self::Runtime; } /// An Arced and type-erased [`Connector`] @@ -44,8 +51,8 @@ impl Debug for ArcedConnector { impl ArcedConnector { /// Constructs a new `ArcedConnector` #[must_use] - pub fn new(handler: impl Connector) -> Self { - Self(Arc::new(handler)) + pub fn new(connector: impl Connector) -> Self { + connector.arced() } /// Determine if this `ArcedConnector` is the specified type @@ -64,6 +71,11 @@ impl ArcedConnector { pub fn downcast_mut(&mut self) -> Option<&mut T> { Arc::get_mut(&mut self.0)?.as_mut_any().downcast_mut() } + + /// Returns an object-safe [`Runtime`] + pub fn runtime(&self) -> Runtime { + self.0.runtime() + } } trait ObjectSafeConnector: Send + Sync + 'static { @@ -76,16 +88,9 @@ trait ObjectSafeConnector: Send + Sync + 'static { 'connector: 'fut, 'url: 'fut, Self: 'fut; - fn spawn(&self, fut: Pin + Send + 'static>>); - fn delay<'connector, 'fut>( - &'connector self, - duration: Duration, - ) -> Pin + Send + 'fut>> - where - 'connector: 'fut, - Self: 'fut; fn as_any(&self) -> &dyn Any; fn as_mut_any(&mut self) -> &mut dyn Any; + fn runtime(&self) -> Runtime; } impl ObjectSafeConnector for T { @@ -100,25 +105,14 @@ impl ObjectSafeConnector for T { { Box::pin(async move { Connector::connect(self, url).await.map(BoxedTransport::new) }) } - fn spawn(&self, fut: Pin + Send + 'static>>) { - Connector::spawn(self, fut) - } - fn as_any(&self) -> &dyn Any { self } fn as_mut_any(&mut self) -> &mut dyn Any { self } - fn delay<'connector, 'fut>( - &'connector self, - duration: Duration, - ) -> Pin + Send + 'fut>> - where - 'connector: 'fut, - Self: 'fut, - { - Box::pin(async move { Connector::delay(self, duration).await }) + fn runtime(&self) -> Runtime { + Connector::runtime(self).into() } } @@ -128,11 +122,13 @@ impl Connector for ArcedConnector { self.0.connect(url).await } - fn spawn + Send + 'static>(&self, fut: Fut) { - self.0.spawn(Box::pin(fut)) + type Runtime = Runtime; + + fn arced(self) -> ArcedConnector { + self } - async fn delay(&self, duration: Duration) { - self.0.delay(duration).await + fn runtime(&self) -> Self::Runtime { + self.0.runtime() } } diff --git a/server-common/src/config.rs b/server-common/src/config.rs index d90f3393d9..4de893e268 100644 --- a/server-common/src/config.rs +++ b/server-common/src/config.rs @@ -1,6 +1,7 @@ -use crate::{Acceptor, Server, ServerHandle}; +use crate::{Acceptor, RuntimeTrait, Server, ServerHandle}; use async_cell::sync::AsyncCell; use std::{ + cell::OnceCell, marker::PhantomData, net::SocketAddr, sync::{Arc, RwLock}, @@ -57,7 +58,7 @@ In order to use this to _implement_ a trillium server, see */ #[derive(Debug)] -pub struct Config { +pub struct Config { pub(crate) acceptor: AcceptorType, pub(crate) port: Option, pub(crate) host: Option, @@ -69,6 +70,7 @@ pub struct Config { pub(crate) binding: RwLock>, pub(crate) server: PhantomData, pub(crate) http_config: HttpConfig, + pub(crate) runtime: ServerType::Runtime, } impl Config @@ -103,7 +105,7 @@ where /// [`ServerHandle::stop`] pub fn spawn(self, handler: impl Handler) -> ServerHandle { let server_handle = self.handle(); - ServerType::spawn(self.run_async(handler)); + self.runtime.clone().spawn(self.run_async(handler)); server_handle } @@ -113,6 +115,8 @@ where ServerHandle { swansong: self.swansong.clone(), info: self.info.clone(), + received_info: OnceCell::new(), + runtime: self.runtime().into(), } } @@ -179,6 +183,7 @@ where info: self.info, binding: self.binding, http_config: self.http_config, + runtime: self.runtime, } } @@ -236,6 +241,11 @@ where .as_deref() .map_or(false, Option::is_some) } + + /// retrieve the runtime + pub fn runtime(&self) -> ServerType::Runtime { + self.runtime.clone() + } } impl Config { @@ -272,6 +282,7 @@ impl Default for Config { info: AsyncCell::shared(), binding: RwLock::new(None), http_config: HttpConfig::default(), + runtime: ServerType::runtime(), } } } diff --git a/server-common/src/config_ext.rs b/server-common/src/config_ext.rs index 65062e336e..491cdeadd4 100644 --- a/server-common/src/config_ext.rs +++ b/server-common/src/config_ext.rs @@ -3,11 +3,10 @@ use futures_lite::prelude::*; use std::{ io::ErrorKind, net::{SocketAddr, TcpListener, ToSocketAddrs}, + sync::Arc, }; use trillium::Handler; -use trillium_http::{ - transport::BoxedTransport, Conn as HttpConn, Error, Swansong, SERVICE_UNAVAILABLE, -}; +use trillium_http::{transport::BoxedTransport, Error, Swansong, SERVICE_UNAVAILABLE}; /// # Server-implementer interfaces to Config /// /// These functions are intended for use by authors of trillium servers, @@ -58,7 +57,7 @@ where /// [`trillium_http`]'s http implementation. this is the default inner /// loop for most trillium servers fn handle_stream( - &self, + self: Arc, stream: ServerType::Transport, handler: impl Handler, ) -> impl Future + Send; @@ -129,7 +128,11 @@ where self.swansong.shut_down().await } - async fn handle_stream(&self, mut stream: ServerType::Transport, handler: impl Handler) { + async fn handle_stream( + self: Arc, + mut stream: ServerType::Transport, + handler: impl Handler, + ) { if self.over_capacity() { let mut byte = [0u8]; // wait for the client to start requesting trillium::log_error!(stream.read(&mut byte).await); @@ -137,13 +140,13 @@ where return; } - let counter = self.swansong.guard(); + let guard = self.swansong.guard(); trillium::log_error!(stream.set_nodelay(self.nodelay)); let peer_ip = stream.peer_addr().ok().flatten().map(|addr| addr.ip()); - let stream = match self.acceptor.accept(stream).await { + let transport = match self.acceptor.accept(stream).await { Ok(stream) => stream, Err(e) => { log::error!("acceptor error: {:?}", e); @@ -152,9 +155,9 @@ where }; let handler = &handler; - let result = HttpConn::map_with_config( + let result = trillium_http::Conn::map_with_config( self.http_config, - stream, + transport, self.swansong.clone(), |mut conn| async { conn.set_peer_ip(peer_ip); @@ -192,7 +195,7 @@ where } }; - drop(counter); + drop(guard); } fn build_listener(&self) -> Listener diff --git a/server-common/src/lib.rs b/server-common/src/lib.rs index 0949d984f8..ca755b301c 100644 --- a/server-common/src/lib.rs +++ b/server-common/src/lib.rs @@ -23,7 +23,7 @@ If you are depending on this crate for private code that cannot be discovered through docs.rs' reverse dependencies, please open an issue. */ -pub use futures_lite::{AsyncRead, AsyncWrite}; +pub use futures_lite::{AsyncRead, AsyncWrite, Stream}; pub use trillium_http::transport::Transport; pub use url; pub use url::Url; @@ -53,3 +53,6 @@ mod arc_handler; pub(crate) use arc_handler::ArcHandler; pub use swansong::Swansong; + +mod runtime; +pub use runtime::{DroppableFuture, Runtime, RuntimeTrait}; diff --git a/server-common/src/runtime.rs b/server-common/src/runtime.rs new file mode 100644 index 0000000000..82a049c283 --- /dev/null +++ b/server-common/src/runtime.rs @@ -0,0 +1,120 @@ +use futures_lite::Stream; +use std::{ + fmt::{self, Debug, Formatter}, + future::Future, + pin::Pin, + sync::Arc, + time::Duration, +}; + +mod droppable_future; +pub use droppable_future::DroppableFuture; + +mod runtime_trait; +pub use runtime_trait::RuntimeTrait; + +mod object_safe_runtime; +use object_safe_runtime::ObjectSafeRuntime; + +/// A type-erased [`RuntimeTrait`] implementation. Think of this as an `Arc` +#[derive(Clone)] +pub struct Runtime(Arc); + +impl Debug for Runtime { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_tuple("Runtime").field(&"..").finish() + } +} + +impl Runtime { + /// Construct a new type-erased runtime object from any [`RuntimeTrait`] implementation. + /// + /// Prefer using [`from`][From::from]/[`into`][Into::into] if you don't have a concrete + /// `RuntimeTrait` in order to avoid double-arc-ing a Runtime. + pub fn new(runtime: impl RuntimeTrait) -> Self { + Self(Arc::new(runtime)) + } + + /// Spawn a future on the runtime, returning a future that has detach-on-drop semantics + /// + /// Spawned tasks conform to the following behavior: + /// + /// * detach on drop: If the returned [`DroppableFuture`] is dropped immediately, the task will + /// continue to execute until completion. + /// + /// * unwinding: If the spawned future panics, this must not propagate to the join + /// handle. Instead, the awaiting the join handle returns None in case of panic. + pub fn spawn( + &self, + fut: impl Future + Send + 'static, + ) -> DroppableFuture> + Send + 'static>>> { + let fut = RuntimeTrait::spawn(self, fut).into_inner(); + DroppableFuture::new(Box::pin(fut)) + } + + /// Wake in this amount of wall time + pub async fn delay(&self, duration: Duration) { + RuntimeTrait::delay(self, duration).await + } + + /// Returns a [`Stream`] that yields a `()` on the provided period + pub fn interval(&self, period: Duration) -> impl Stream + Send + '_ { + RuntimeTrait::interval(self, period) + } + + /// Runtime implementation hook for blocking on a top level future. + pub fn block_on(&self, fut: Fut) -> Fut::Output + where + Fut: Future, + { + RuntimeTrait::block_on(self, fut) + } + + /// Race a future against the provided duration, returning None in case of timeout. + pub async fn timeout(&self, duration: Duration, fut: Fut) -> Option + where + Fut: Future + Send, + Fut::Output: Send + 'static, + { + RuntimeTrait::timeout(self, duration, fut).await + } +} + +impl RuntimeTrait for Runtime { + async fn delay(&self, duration: Duration) { + self.0.delay(duration).await + } + + fn interval(&self, period: Duration) -> impl Stream + Send + 'static { + self.0.interval(period) + } + + fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let (send, receive) = async_channel::bounded(1); + let spawn_fut = self.0.spawn(Box::pin(async move { + let _ = send.try_send(fut.await); + })); + DroppableFuture::new(Box::pin(async move { + spawn_fut.await; + receive.try_recv().ok() + })) + } + + fn block_on(&self, fut: Fut) -> Fut::Output + where + Fut: Future, + { + let (send, receive) = std::sync::mpsc::channel(); + self.0.block_on(Box::pin(async move { + let _ = send.send(fut.await); + })); + receive.recv().unwrap() + } +} diff --git a/server-common/src/runtime/droppable_future.rs b/server-common/src/runtime/droppable_future.rs new file mode 100644 index 0000000000..5d82b3798e --- /dev/null +++ b/server-common/src/runtime/droppable_future.rs @@ -0,0 +1,41 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use pin_project_lite::pin_project; + +pin_project! { + /// A wrapper type for futures that do not need to be polled but still can be awaited. + /// + /// This exists to silence the default `#[must_use]` that anonymous async functions return + /// + /// Futures contained by this type must conform to the semantics of trillium join handles described + /// at [RuntimeTrait::spawn]. + #[derive(Debug, Clone)] + pub struct DroppableFuture { + #[pin] future: T + } +} +impl DroppableFuture { + /// Removes the #[must_use] for a future. + /// + /// This must only be called with a join-handle type future that does not depend on polling to + /// execute. + pub fn new(future: T) -> Self { + Self { future } + } + + /// Returns the inner future. + pub fn into_inner(self) -> T { + self.future + } +} +impl Future for DroppableFuture { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.project().future.poll(cx) + } +} diff --git a/server-common/src/runtime/object_safe_runtime.rs b/server-common/src/runtime/object_safe_runtime.rs new file mode 100644 index 0000000000..5fe22c4dab --- /dev/null +++ b/server-common/src/runtime/object_safe_runtime.rs @@ -0,0 +1,57 @@ +use super::RuntimeTrait; +use futures_lite::Stream; +use std::{future::Future, pin::Pin, time::Duration}; + +pub(super) trait ObjectSafeRuntime: Send + Sync + 'static { + fn spawn( + &self, + fut: Pin + Send + 'static>>, + ) -> Pin> + Send + 'static>>; + fn delay<'runtime, 'fut>( + &'runtime self, + duration: Duration, + ) -> Pin + Send + 'fut>> + where + 'runtime: 'fut, + Self: 'fut; + fn interval(&self, period: Duration) -> Pin + Send + 'static>>; + fn block_on<'runtime, 'fut>(&'runtime self, fut: Pin + 'fut>>) + where + 'runtime: 'fut, + Self: 'fut; +} + +impl ObjectSafeRuntime for R +where + R: RuntimeTrait + Send + Sync + 'static, +{ + fn spawn( + &self, + fut: Pin + Send + 'static>>, + ) -> Pin> + Send + 'static>> { + Box::pin(RuntimeTrait::spawn(self, Box::pin(fut))) + } + + fn delay<'runtime, 'fut>( + &'runtime self, + duration: Duration, + ) -> Pin + Send + 'fut>> + where + 'runtime: 'fut, + Self: 'fut, + { + Box::pin(RuntimeTrait::delay(self, duration)) + } + + fn interval(&self, period: Duration) -> Pin + Send + 'static>> { + Box::pin(RuntimeTrait::interval(self, period)) + } + + fn block_on<'runtime, 'fut>(&'runtime self, fut: Pin + 'fut>>) + where + 'runtime: 'fut, + Self: 'fut, + { + RuntimeTrait::block_on(self, fut) + } +} diff --git a/server-common/src/runtime/runtime_trait.rs b/server-common/src/runtime/runtime_trait.rs new file mode 100644 index 0000000000..5889e1314e --- /dev/null +++ b/server-common/src/runtime/runtime_trait.rs @@ -0,0 +1,54 @@ +use super::{DroppableFuture, Runtime}; +use futures_lite::{FutureExt, Stream}; +use std::{future::Future, time::Duration}; + +/// A trait that covers async runtime behavior. +/// +/// You likely do not need to name this type. For a type-erased runtime, see [`Runtime`] +pub trait RuntimeTrait: Into + Clone + Send + Sync + 'static { + /// Spawn a future on the runtime, returning a future that has detach-on-drop semantics + /// + /// As the various runtimes each has different behavior for spawn, implementations of this trait + /// are expected to conform to the following: + /// + /// * detach on drop: If the returned [`DroppableFuture`] is dropped immediately, the task will + /// continue to execute until completion. + /// + /// * unwinding: If the spawned future panics, this must not propagate to the join + /// handle. Instead, the awaiting the join handle returns None in case of panic. + fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static; + + /// Wake in this amount of wall time + fn delay(&self, duration: Duration) -> impl Future + Send; + + /// Returns a [`Stream`] that yields a `()` on the provided period + fn interval(&self, period: Duration) -> impl Stream + Send + 'static; + + /// Runtime implementation hook for blocking on a top level future. + fn block_on(&self, fut: Fut) -> Fut::Output + where + Fut: Future; + + /// Race a future against the provided duration, returning None in case of timeout. + fn timeout<'runtime, 'fut, Fut>( + &'runtime self, + duration: Duration, + fut: Fut, + ) -> impl Future> + Send + 'fut + where + Fut: Future + Send + 'fut, + Fut::Output: Send + 'static, + 'runtime: 'fut, + { + async move { Some(fut.await) }.race(async move { + self.delay(duration).await; + None + }) + } +} diff --git a/server-common/src/server.rs b/server-common/src/server.rs index dce32c4bb0..14f903f0da 100644 --- a/server-common/src/server.rs +++ b/server-common/src/server.rs @@ -1,4 +1,4 @@ -use crate::{Acceptor, ArcHandler, Config, ConfigExt, Swansong, Transport}; +use crate::{Acceptor, ArcHandler, Config, ConfigExt, RuntimeTrait, Swansong, Transport}; use std::{future::Future, io::Result, sync::Arc}; use trillium::{Handler, Info}; @@ -11,6 +11,9 @@ pub trait Server: Sized + Send + Sync + 'static { /// like TcpStream or UnixStream. See [`Transport`] type Transport: Transport; + /// The [`RuntimeTrait`] for this `Server`. + type Runtime: RuntimeTrait; + /// The description of this server, to be appended to the Info and potentially logged. const DESCRIPTION: &'static str; @@ -28,6 +31,9 @@ pub trait Server: Sized + Send + Sync + 'static { async {} } + /// Return this `Server`'s `Runtime` + fn runtime() -> Self::Runtime; + /// Build a listener from the config. The default logic for this /// is described elsewhere. To override the default logic, server /// implementations could potentially implement this directly. To @@ -105,19 +111,16 @@ pub trait Server: Sized + Send + Sync + 'static { async {} } - /// Runtime implementation hook for spawning a task. - fn spawn(fut: impl Future + Send + 'static); - - /// Runtime implementation hook for blocking on a top level future. - fn block_on(fut: impl Future + 'static); - /// Run a trillium application from a sync context fn run(config: Config, handler: H) where A: Acceptor, H: Handler, { - Self::block_on(Self::run_async(config, handler)) + config + .runtime + .clone() + .block_on(async move { Self::run_async(config, handler).await }); } /// Run a trillium application from an async context. The default @@ -129,14 +132,14 @@ pub trait Server: Sized + Send + Sync + 'static { H: Handler, { async move { + let runtime = config.runtime.clone(); if config.should_register_signals() { #[cfg(unix)] - Self::spawn(Self::handle_signals(config.swansong())); + runtime.spawn(Self::handle_signals(config.swansong())); #[cfg(not(unix))] log::error!("signals handling not supported on windows yet"); } - let mut listener = Self::build_listener(&config); let mut info = Self::info(&listener); info.server_description_mut().push_str(Self::DESCRIPTION); @@ -144,13 +147,14 @@ pub trait Server: Sized + Send + Sync + 'static { config.info.set(Arc::new(info)); let config = Arc::new(config); let handler = ArcHandler::new(handler); + let swansong = &config.swansong; - while let Some(stream) = config.swansong.interrupt(Self::accept(&mut listener)).await { - match stream { + while let Some(transport) = swansong.interrupt(Self::accept(&mut listener)).await { + match transport { Ok(stream) => { let config = Arc::clone(&config); let handler = ArcHandler::clone(&handler); - Self::spawn(async move { config.handle_stream(stream, handler).await }) + runtime.spawn(config.handle_stream(stream, handler)); } Err(e) => log::error!("tcp error: {}", e), } diff --git a/server-common/src/server_handle.rs b/server-common/src/server_handle.rs index e3e21f6223..3062c63f5c 100644 --- a/server-common/src/server_handle.rs +++ b/server-common/src/server_handle.rs @@ -1,5 +1,6 @@ +use crate::Runtime; use async_cell::sync::AsyncCell; -use std::{future::IntoFuture, sync::Arc}; +use std::{cell::OnceCell, future::IntoFuture, sync::Arc}; use swansong::{ShutdownCompletion, Swansong}; use trillium::Info; @@ -10,23 +11,34 @@ use trillium::Info; pub struct ServerHandle { pub(crate) swansong: Swansong, pub(crate) info: Arc>>, + pub(crate) received_info: OnceCell>, + pub(crate) runtime: Runtime, } impl ServerHandle { /// await server start and retrieve the server's [`Info`] - pub async fn info(&self) -> Arc { - self.info.get().await + pub async fn info(&self) -> &Info { + if let Some(info) = self.received_info.get() { + return info; + } + let arc_info = self.info.get().await; + self.received_info.get_or_init(|| arc_info) } - /// stop server and wait for it to shut down gracefully - pub async fn shut_down(&self) { - self.swansong.shut_down().await; + /// stop server and return a future that can be awaited for it to shut down gracefully + pub fn shut_down(&self) -> ShutdownCompletion { + self.swansong.shut_down() } /// retrieves a clone of the [`Swansong`] used by this server pub fn swansong(&self) -> Swansong { self.swansong.clone() } + + /// retrieves a runtime + pub fn runtime(&self) -> Runtime { + self.runtime.clone() + } } impl IntoFuture for ServerHandle { diff --git a/smol/Cargo.toml b/smol/Cargo.toml index b6da8b8138..487e59214f 100644 --- a/smol/Cargo.toml +++ b/smol/Cargo.toml @@ -11,9 +11,11 @@ keywords = ["trillium", "framework", "async"] categories = ["web-programming::http-server", "web-programming"] [dependencies] -async-global-executor = { version = "2.4.1", features = ["async-io"] } +async-executor = "1.10.0" +async-global-executor = {version = "2.4.1", features = ["async-io"] } async-io = "2.2.2" async-net = "2.0.0" +async-task = "4.7.0" futures-lite = "2.1.0" log = "0.4.20" trillium = { path = "../trillium", version = "0.2.19" } @@ -27,6 +29,8 @@ signal-hook-async-std = "0.2.2" [dev-dependencies] env_logger = "0.11.0" +tempfile = "3.10.1" +test-harness = "0.2.0" trillium-client = { path = "../client" } trillium-logger = { path = "../logger" } trillium-testing = { path = "../testing" } diff --git a/smol/examples/smol.rs b/smol/examples/smol.rs index 02567684b5..2d32e8ff1e 100644 --- a/smol/examples/smol.rs +++ b/smol/examples/smol.rs @@ -1,25 +1,30 @@ -use trillium_smol::async_io::Timer; +use std::time::Duration; +use trillium::{Conn, Handler}; +use trillium_logger::Logger; +use trillium_smol::SmolRuntime; -pub fn app() -> impl trillium::Handler { - ( - trillium_logger::Logger::new(), - |conn: trillium::Conn| async move { - let response = async_global_executor::spawn(async { - Timer::after(std::time::Duration::from_millis(10)).await; +pub fn app() -> impl Handler { + (Logger::new(), |conn: Conn| async move { + let runtime = SmolRuntime::default(); + let response = runtime + .clone() + .spawn(async move { + runtime.delay(Duration::from_millis(100)).await; "successfully spawned a task" }) - .await; + .await + .unwrap_or_default(); - conn.ok(response) - }, - ) + conn.ok(response) + }) } + pub fn main() { env_logger::init(); trillium_smol::run(app()); } -#[cfg(all(test, feature = "smol"))] +#[cfg(test)] mod tests { use trillium_testing::prelude::*; #[test] diff --git a/smol/src/client.rs b/smol/src/client.rs index 2414959621..317f061aa7 100644 --- a/smol/src/client.rs +++ b/smol/src/client.rs @@ -1,9 +1,6 @@ -use crate::SmolTransport; +use crate::{SmolRuntime, SmolTransport}; use async_net::TcpStream; -use std::{ - future::Future, - io::{Error, ErrorKind, Result}, -}; +use std::io::{Error, ErrorKind, Result}; use trillium_server_common::{ url::{Host, Url}, Connector, Transport, @@ -45,6 +42,11 @@ impl ClientConfig { impl Connector for ClientConfig { type Transport = SmolTransport; + type Runtime = SmolRuntime; + + fn runtime(&self) -> Self::Runtime { + SmolRuntime::default() + } async fn connect(&self, url: &Url) -> Result { if url.scheme() != "http" { @@ -79,12 +81,19 @@ impl Connector for ClientConfig { Ok(tcp) } +} + +#[cfg(unix)] +impl Connector for SmolTransport { + type Transport = Self; + + type Runtime = SmolRuntime; - fn spawn + Send + 'static>(&self, fut: Fut) { - async_global_executor::spawn(fut).detach(); + async fn connect(&self, _url: &Url) -> Result { + Ok(self.clone()) } - async fn delay(&self, duration: std::time::Duration) { - async_io::Timer::after(duration).await; + fn runtime(&self) -> Self::Runtime { + SmolRuntime::default() } } diff --git a/smol/src/lib.rs b/smol/src/lib.rs index 3bb8332137..9c68f605a4 100644 --- a/smol/src/lib.rs +++ b/smol/src/lib.rs @@ -60,7 +60,7 @@ trillium_testing::with_server("ok", |url| async move { */ use trillium::Handler; -pub use trillium_server_common::{Binding, Swansong}; +pub use trillium_server_common::{Binding, Connector, Runtime, RuntimeTrait, Swansong}; mod client; pub use client::ClientConfig; @@ -75,6 +75,9 @@ pub use async_global_executor; pub use async_io; pub use async_net; +mod runtime; +pub use runtime::SmolRuntime; + /** # Runs a trillium handler in a sync context with default config @@ -140,8 +143,3 @@ See [`trillium_server_common::Config`] for more details pub fn config() -> Config<()> { Config::new() } - -/// spawn and detach a Future that returns () -pub fn spawn + Send + 'static>(future: Fut) { - async_global_executor::spawn(future).detach(); -} diff --git a/smol/src/runtime.rs b/smol/src/runtime.rs new file mode 100644 index 0000000000..12396a3e36 --- /dev/null +++ b/smol/src/runtime.rs @@ -0,0 +1,110 @@ +use async_io::Timer; +use async_task::Task; +use futures_lite::{FutureExt, Stream, StreamExt}; +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, + time::Duration, +}; +use trillium_server_common::{DroppableFuture, Runtime, RuntimeTrait}; + +/// Runtime for Smol +#[derive(Debug, Clone, Copy, Default)] +pub struct SmolRuntime(()); + +struct DetachOnDrop(Option>); +impl Future for DetachOnDrop { + type Output = Output; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + Pin::new(self.0.as_mut().unwrap()).poll(cx) + } +} + +impl Drop for DetachOnDrop { + fn drop(&mut self) { + if let Some(task) = self.0.take() { + task.detach(); + } + } +} + +impl RuntimeTrait for SmolRuntime { + fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let join_handle = DetachOnDrop(Some(async_global_executor::spawn(fut))); + DroppableFuture::new(async move { join_handle.catch_unwind().await.ok() }) + } + + async fn delay(&self, duration: Duration) { + Timer::after(duration).await; + } + + fn interval(&self, period: Duration) -> impl Stream + Send + 'static { + Timer::interval(period).map(|_| ()) + } + + fn block_on(&self, fut: Fut) -> Fut::Output { + async_global_executor::block_on(fut) + } +} + +impl SmolRuntime { + /// Spawn a future on the runtime, returning a future that has detach-on-drop semantics + /// + /// Spawned tasks conform to the following behavior: + /// + /// * detach on drop: If the returned [`DroppableFuture`] is dropped immediately, the task will + /// continue to execute until completion. + /// + /// * unwinding: If the spawned future panics, this must not propagate to the join + /// handle. Instead, the awaiting the join handle returns None in case of panic. + pub fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let join_handle = DetachOnDrop(Some(async_global_executor::spawn(fut))); + DroppableFuture::new(async move { join_handle.catch_unwind().await.ok() }) + } + + /// Wake in this amount of wall time + pub async fn delay(&self, duration: Duration) { + Timer::after(duration).await; + } + + /// Returns a [`Stream`] that yields a `()` on the provided period + pub fn interval(&self, period: Duration) -> impl Stream + Send + 'static { + Timer::interval(period).map(|_| ()) + } + + /// Runtime implementation hook for blocking on a top level future. + pub fn block_on(&self, fut: Fut) -> Fut::Output { + RuntimeTrait::block_on(self, fut) + } + + /// Race a future against the provided duration, returning None in case of timeout. + pub async fn timeout(&self, duration: Duration, fut: Fut) -> Option + where + Fut: Future + Send, + Fut::Output: Send + 'static, + { + RuntimeTrait::timeout(self, duration, fut).await + } +} + +impl From for Runtime { + fn from(value: SmolRuntime) -> Self { + Runtime::new(value) + } +} diff --git a/smol/src/server.rs b/smol/src/server.rs index d2832310bc..d607ee802a 100644 --- a/smol/src/server.rs +++ b/smol/src/server.rs @@ -3,9 +3,9 @@ mod unix; #[cfg(unix)] pub use unix::SmolServer; -#[cfg(not(unix))] mod tcp; + #[cfg(not(unix))] -pub use tcp::SmolServer; +pub use tcp::SmolTcpServer as SmolServer; pub type Config = trillium_server_common::Config; diff --git a/smol/src/server/tcp.rs b/smol/src/server/tcp.rs index fa49d8dafd..7ff8dd2711 100644 --- a/smol/src/server/tcp.rs +++ b/smol/src/server/tcp.rs @@ -1,21 +1,21 @@ -use crate::SmolTransport; -use async_global_executor::{block_on, spawn}; +use crate::{SmolRuntime, SmolTransport}; use async_net::{TcpListener, TcpStream}; -use futures_lite::prelude::*; -use std::{convert::TryInto, env, io::Result}; +use std::{convert::TryInto, env, io::Result, net}; use trillium::Info; -use trillium_server_common::Server; +use trillium_server_common::{Server, Url}; #[derive(Debug)] -pub struct SmolServer(TcpListener); -impl From for SmolServer { +pub struct SmolTcpServer(TcpListener); +impl From for SmolTcpServer { fn from(value: TcpListener) -> Self { Self(value) } } -impl Server for SmolServer { +impl Server for SmolTcpServer { type Transport = SmolTransport; + type Runtime = SmolRuntime; + const DESCRIPTION: &'static str = concat!( " (", env!("CARGO_PKG_NAME"), @@ -28,19 +28,20 @@ impl Server for SmolServer { self.0.accept().await.map(|(t, _)| t.into()) } - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { + fn listener_from_tcp(tcp: net::TcpListener) -> Self { Self(tcp.try_into().unwrap()) } fn info(&self) -> Info { - self.0.local_addr().unwrap().into() - } - - fn spawn(fut: impl Future + Send + 'static) { - spawn(fut).detach(); + let local_addr = self.0.local_addr().unwrap(); + let mut info = Info::from(local_addr); + if let Ok(url) = Url::parse(&format!("http://{local_addr}")) { + info.state_mut().insert(url); + } + info } - fn block_on(fut: impl Future + 'static) { - block_on(fut) + fn runtime() -> Self::Runtime { + SmolRuntime::default() } } diff --git a/smol/src/server/unix.rs b/smol/src/server/unix.rs index f81b07bf4f..36cea416cb 100644 --- a/smol/src/server/unix.rs +++ b/smol/src/server/unix.rs @@ -1,5 +1,4 @@ -use crate::SmolTransport; -use async_global_executor::{block_on, spawn}; +use crate::{SmolRuntime, SmolTransport}; use async_net::{ unix::{UnixListener, UnixStream}, TcpListener, TcpStream, @@ -9,7 +8,7 @@ use std::{env, io::Result}; use trillium::{log_error, Info}; use trillium_server_common::{ Binding::{self, *}, - Server, Swansong, + Server, Swansong, Url, }; #[derive(Debug, Clone)] @@ -28,6 +27,7 @@ impl From for SmolServer { #[cfg(unix)] impl Server for SmolServer { type Transport = Binding, SmolTransport>; + type Runtime = SmolRuntime; const DESCRIPTION: &'static str = concat!( " (", env!("CARGO_PKG_NAME"), @@ -36,6 +36,10 @@ impl Server for SmolServer { ")" ); + fn runtime() -> Self::Runtime { + SmolRuntime::default() + } + async fn handle_signals(swansong: Swansong) { use signal_hook::consts::signal::*; use signal_hook_async_std::Signals; @@ -70,19 +74,18 @@ impl Server for SmolServer { fn info(&self) -> Info { match &self.0 { - Tcp(t) => t.local_addr().unwrap().into(), + Tcp(t) => { + let local_addr = t.local_addr().unwrap(); + let mut info = Info::from(local_addr); + if let Ok(url) = Url::parse(&format!("http://{local_addr}")) { + info.state_mut().insert(url); + } + info + } Unix(u) => u.local_addr().unwrap().into(), } } - fn spawn(fut: impl Future + Send + 'static) { - spawn(fut).detach(); - } - - fn block_on(fut: impl Future + 'static) { - block_on(fut) - } - async fn clean_up(self) { if let Unix(u) = &self.0 { if let Ok(local) = u.local_addr() { diff --git a/smol/src/transport.rs b/smol/src/transport.rs index 38eff92234..22646a797e 100644 --- a/smol/src/transport.rs +++ b/smol/src/transport.rs @@ -43,3 +43,10 @@ impl Transport for SmolTransport { #[cfg(unix)] impl Transport for SmolTransport {} +#[cfg(unix)] +impl SmolTransport { + /// connnect to the provided path as a unix domain socket + pub async fn connect_unix(path: impl AsRef) -> Result { + async_net::unix::UnixStream::connect(path).await.map(Self) + } +} diff --git a/smol/tests/unix_stream.rs b/smol/tests/unix_stream.rs new file mode 100644 index 0000000000..2a21865418 --- /dev/null +++ b/smol/tests/unix_stream.rs @@ -0,0 +1,21 @@ +#![cfg(unix)] +use test_harness::test; +use trillium_client::Client; +use trillium_smol::{config, SmolTransport}; +use trillium_testing::harness; + +#[test(harness)] +async fn smoke() { + let temp_dir = tempfile::tempdir().unwrap(); + let path = temp_dir.path().join("socket"); + + let handle = config().with_host(path.to_str().unwrap()).spawn("ok"); + handle.info().await; + + let stream = SmolTransport::connect_unix(path).await.unwrap(); + let mut conn = Client::new(stream).get("http://localhost/").await.unwrap(); + + assert_eq!(conn.status().unwrap(), 200); + assert_eq!(conn.response_body().read_string().await.unwrap(), "ok"); + handle.shut_down().await; +} diff --git a/testing/Cargo.toml b/testing/Cargo.toml index b0dda2c357..10e6db4254 100644 --- a/testing/Cargo.toml +++ b/testing/Cargo.toml @@ -22,7 +22,6 @@ default = [] [dependencies] async-dup = "1.2.4" futures-lite = "2.1.0" -portpicker = "0.1.1" trillium = { path = "../trillium", version = "0.2.19" } trillium-http = { path = "../http", version = "0.3.16" } trillium-server-common = { path = "../server-common", version = "0.5.2" } @@ -33,6 +32,7 @@ trillium-macros = { version = "0.0.6", path = "../macros" } dashmap = "5.5.3" once_cell = "1.19.0" fastrand = "2.0.1" +log = "0.4.21" [dependencies.trillium-smol] path = "../smol" diff --git a/testing/src/lib.rs b/testing/src/lib.rs index 4128b7848a..bd55fa74a9 100644 --- a/testing/src/lib.rs +++ b/testing/src/lib.rs @@ -65,7 +65,7 @@ trillium-testing = { version = "0.2", features = ["smol"] } mod assertions; mod test_transport; -use std::future::{Future, IntoFuture}; +use std::future::Future; use std::process::Termination; pub use test_transport::TestTransport; @@ -90,38 +90,26 @@ pub use trillium::{Method, Status}; pub use url::Url; +/// runs the future to completion on the current thread +pub fn block_on(fut: Fut) -> Fut::Output { + runtime().block_on(fut) +} + /// initialize a handler pub fn init(handler: &mut impl trillium::Handler) { let mut info = "testing".into(); - block_on(handler.init(&mut info)) + block_on(async move { handler.init(&mut info).await }) } // these exports are used by macros pub use futures_lite; -pub use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite}; +pub use futures_lite::{AsyncRead, AsyncReadExt, AsyncWrite, Stream}; mod server_connector; pub use server_connector::{connector, ServerConnector}; use trillium_server_common::Config; -pub use trillium_server_common::{ArcedConnector, Connector, Server}; - -#[derive(Debug)] -/// A droppable future -/// -/// This only exists because of the #[must_use] on futures. The task will run to completion whether -/// or not this future is awaited. -pub struct SpawnHandle(F); -impl IntoFuture for SpawnHandle -where - F: Future, -{ - type IntoFuture = F; - type Output = F::Output; - fn into_future(self) -> Self::IntoFuture { - self.0 - } -} +pub use trillium_server_common::{ArcedConnector, Connector, Runtime, RuntimeTrait, Server}; cfg_if::cfg_if! { if #[cfg(feature = "smol")] { @@ -130,142 +118,72 @@ cfg_if::cfg_if! { trillium_smol::config() } - /// smol-based spawn variant that finishes whether or not the returned future is dropped - pub fn spawn(future: Fut) -> SpawnHandle>> - where - Fut: Future + Send + 'static, - Out: Send + 'static - { - let (tx, rx) = async_channel::bounded::(1); - trillium_smol::async_global_executor::spawn(async move { let _ = tx.send(future.await).await; }).detach(); - SpawnHandle(async move { - let rx = rx; - rx.recv().await.ok() - }) - } - /// runtime client config pub fn client_config() -> impl Connector { - ClientConfig::default() + trillium_smol::ClientConfig::default() } - pub use trillium_smol::async_global_executor::block_on; - pub use trillium_smol::ClientConfig; - /// a future that wakes after this amount of time - pub async fn delay(duration: std::time::Duration) { - trillium_smol::async_io::Timer::after(duration).await; + /// smol runtime + pub fn runtime() -> impl RuntimeTrait { + trillium_smol::SmolRuntime::default() } - + pub(crate) use trillium_smol::SmolRuntime as RuntimeType; } else if #[cfg(feature = "async-std")] { /// runtime server config pub fn config() -> Config { trillium_async_std::config() } - pub use trillium_async_std::async_std::task::block_on; - pub use trillium_async_std::ClientConfig; - - /// async-std-based spawn variant that finishes whether or not the returned future is dropped - pub fn spawn(future: Fut) -> SpawnHandle>> - where - Fut: Future + Send + 'static, - Out: Send + 'static - { - let (tx, rx) = async_channel::bounded::(1); - trillium_async_std::async_std::task::spawn(async move { let _ = tx.send(future.await).await; }); - SpawnHandle(async move { - let rx = rx; - rx.recv().await.ok() - }) - } /// runtime client config pub fn client_config() -> impl Connector { - ClientConfig::default() + trillium_async_std::ClientConfig::default() } - - /// a future that wakes after this amount of time - pub async fn delay(duration: std::time::Duration) { - let _ = trillium_async_std::async_std::future::timeout( - duration, - std::future::pending::<()>() - ).await; + /// async std runtime + pub fn runtime() -> impl RuntimeTrait { + trillium_async_std::AsyncStdRuntime::default() } + pub(crate) use trillium_async_std::AsyncStdRuntime as RuntimeType; } else if #[cfg(feature = "tokio")] { /// runtime server config pub fn config() -> Config { trillium_tokio::config() } - pub use trillium_tokio::ClientConfig; - pub use trillium_tokio::block_on; - /// tokio-based spawn variant that finishes whether or not the returned future is dropped - pub fn spawn(future: Fut) -> SpawnHandle>> - where - Fut: Future + Send + 'static, - Out: Send + 'static - { - let (tx, rx) = async_channel::bounded::(1); - trillium_tokio::tokio::task::spawn(async move { let _ = tx.send(future.await).await; }); - SpawnHandle(async move { - let rx = rx; - rx.recv().await.ok() - }) - } - /// runtime client config + + /// tokio client config pub fn client_config() -> impl Connector { - ClientConfig::default() + trillium_tokio::ClientConfig::default() } - /// a future that wakes after this amount of time - pub async fn delay(duration: std::time::Duration) { - trillium_tokio::tokio::time::sleep(duration).await; + /// tokio runtime + pub fn runtime() -> impl RuntimeTrait { + trillium_tokio::TokioRuntime::default() } + pub(crate) use trillium_tokio::TokioRuntime as RuntimeType; } else { /// runtime server config pub fn config() -> Config { Config::::new() } - pub use RuntimelessClientConfig as ClientConfig; - /// generic client config pub fn client_config() -> impl Connector { RuntimelessClientConfig::default() } - pub use futures_lite::future::block_on; - - /// fake runtimeless spawn that finishes whether or not the future is dropped - pub fn spawn(future: Fut) -> SpawnHandle>> - where - Fut: Future + Send + 'static, - Out: Send + 'static - { - let (tx, rx) = async_channel::bounded::(1); - std::thread::spawn(move || { let _ = tx.send_blocking(block_on(future)); }); - SpawnHandle(async move { - let rx = rx; - rx.recv().await.ok() - }) - } - - /// a future that wakes after this amount of time - pub async fn delay(duration: std::time::Duration) { - let (sender, receiver) = async_channel::bounded::<()>(1); - std::thread::spawn(move || { - std::thread::sleep(duration); - let _ = sender.send_blocking(()); - }); - - let _ = receiver.recv().await; + /// generic runtime + pub fn runtime() -> impl RuntimeTrait { + RuntimelessRuntime::default() } - } + + pub(crate) use RuntimelessRuntime as RuntimeType; + } } mod with_server; pub use with_server::{with_server, with_transport}; mod runtimeless; -pub use runtimeless::{RuntimelessClientConfig, RuntimelessServer}; +pub use runtimeless::{RuntimelessClientConfig, RuntimelessRuntime, RuntimelessServer}; /// a sponge Result pub type TestResult = Result<(), Box>; diff --git a/testing/src/runtimeless.rs b/testing/src/runtimeless.rs index 00ec6fc6e8..9dfc23cee1 100644 --- a/testing/src/runtimeless.rs +++ b/testing/src/runtimeless.rs @@ -1,158 +1,16 @@ -use crate::{spawn, TestTransport}; +use crate::TestTransport; use async_channel::{Receiver, Sender}; use dashmap::DashMap; use once_cell::sync::Lazy; -use std::{ - future::Future, - io::{Error, ErrorKind, Result}, -}; -use trillium::Info; -use trillium_server_common::{Acceptor, Config, ConfigExt, Connector, Server}; -use url::Url; type Servers = Lazy, Receiver)>>; static SERVERS: Servers = Lazy::new(Default::default); -/// A [`Server`] for testing that does not depend on any runtime -#[derive(Debug)] -pub struct RuntimelessServer { - host: String, - port: u16, - channel: Receiver, -} +mod runtime; +pub use runtime::RuntimelessRuntime; -impl Server for RuntimelessServer { - type Transport = TestTransport; - const DESCRIPTION: &'static str = "test server"; - async fn accept(&mut self) -> Result { - self.channel - .recv() - .await - .map_err(|e| Error::new(ErrorKind::Other, e.to_string())) - } +mod server; +pub use server::RuntimelessServer; - fn build_listener(config: &Config) -> Self - where - A: Acceptor, - { - let mut port = config.port(); - let host = config.host(); - if port == 0 { - loop { - port = fastrand::u16(..); - if !SERVERS.contains_key(&(host.clone(), port)) { - break; - } - } - } - - let entry = SERVERS - .entry((host.clone(), port)) - .or_insert_with(async_channel::unbounded); - - let (_, channel) = entry.value(); - - Self { - host, - channel: channel.clone(), - port, - } - } - - fn info(&self) -> Info { - Info::from(&*format!("{}:{}", &self.host, &self.port)) - } - - fn block_on(fut: impl Future + 'static) { - crate::block_on(fut) - } - - fn spawn(fut: impl Future + Send + 'static) { - spawn(fut); - } -} - -impl Drop for RuntimelessServer { - fn drop(&mut self) { - SERVERS.remove(&(self.host.clone(), self.port)); - } -} - -/// An in-memory Connector to use with GenericServer. -#[derive(Default, Debug, Clone, Copy)] -pub struct RuntimelessClientConfig(()); - -impl RuntimelessClientConfig { - /// constructs a GenericClientConfig - pub fn new() -> Self { - Self(()) - } -} - -impl Connector for RuntimelessClientConfig { - type Transport = TestTransport; - async fn connect(&self, url: &Url) -> Result { - let (tx, _) = &*SERVERS - .get(&( - url.host_str().unwrap().to_string(), - url.port_or_known_default().unwrap(), - )) - .ok_or(Error::new(ErrorKind::AddrNotAvailable, "not available"))?; - let (client_transport, server_transport) = TestTransport::new(); - tx.send(server_transport).await.unwrap(); - Ok(client_transport) - } - - fn spawn + Send + 'static>(&self, fut: Fut) { - spawn(fut); - } - - async fn delay(&self, duration: std::time::Duration) { - let (sender, receiver) = async_channel::bounded::<()>(1); - std::thread::spawn(move || { - std::thread::sleep(duration); - let _ = sender.send_blocking(()); - }); - - let _ = receiver.recv().await; - } -} - -#[cfg(test)] -mod test { - use super::*; - use crate::{harness, TestResult}; - use test_harness::test; - #[test(harness)] - async fn round_trip() -> TestResult { - let handle1 = Config::::new() - .with_host("host.com") - .with_port(80) - .spawn("server 1"); - handle1.info().await; - - let handle2 = Config::::new() - .with_host("other_host.com") - .with_port(80) - .spawn("server 2"); - handle2.info().await; - - let client = trillium_client::Client::new(RuntimelessClientConfig::default()); - let mut conn = client.get("http://host.com").await?; - assert_eq!(conn.response_body().await?, "server 1"); - - let mut conn = client.get("http://other_host.com").await?; - assert_eq!(conn.response_body().await?, "server 2"); - - handle1.shut_down().await; - assert!(client.get("http://host.com").await.is_err()); - assert!(client.get("http://other_host.com").await.is_ok()); - - handle2.shut_down().await; - assert!(client.get("http://other_host.com").await.is_err()); - - assert!(SERVERS.is_empty()); - - Ok(()) - } -} +mod client; +pub use client::RuntimelessClientConfig; diff --git a/testing/src/runtimeless/client.rs b/testing/src/runtimeless/client.rs new file mode 100644 index 0000000000..1240071303 --- /dev/null +++ b/testing/src/runtimeless/client.rs @@ -0,0 +1,36 @@ +use super::SERVERS; +use crate::{RuntimelessRuntime, TestTransport}; +use std::io::{Error, ErrorKind, Result}; +use trillium_server_common::Connector; +use url::Url; + +/// An in-memory Connector to use with RuntimelessServer. +#[derive(Default, Debug, Clone, Copy)] +pub struct RuntimelessClientConfig(()); + +impl RuntimelessClientConfig { + /// constructs a RuntimelessClientConfig + pub fn new() -> Self { + Self(()) + } +} + +impl Connector for RuntimelessClientConfig { + type Transport = TestTransport; + type Runtime = RuntimelessRuntime; + async fn connect(&self, url: &Url) -> Result { + let (tx, _) = &*SERVERS + .get(&( + url.host_str().unwrap().to_string(), + url.port_or_known_default().unwrap(), + )) + .ok_or(Error::new(ErrorKind::AddrNotAvailable, "not available"))?; + let (client_transport, server_transport) = TestTransport::new(); + tx.send(server_transport).await.unwrap(); + Ok(client_transport) + } + + fn runtime(&self) -> Self::Runtime { + RuntimelessRuntime::default() + } +} diff --git a/testing/src/runtimeless/runtime.rs b/testing/src/runtimeless/runtime.rs new file mode 100644 index 0000000000..13901df928 --- /dev/null +++ b/testing/src/runtimeless/runtime.rs @@ -0,0 +1,112 @@ +use futures_lite::{future, Stream}; +use std::{future::Future, thread, time::Duration}; +use trillium_server_common::{DroppableFuture, Runtime, RuntimeTrait}; + +/// a runtime that isn't a runtime +#[derive(Debug, Clone, Copy, Default)] +pub struct RuntimelessRuntime(()); +impl RuntimeTrait for RuntimelessRuntime { + fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let rt = *self; + let (send, receive) = async_channel::bounded(1); + thread::spawn(move || { + let _ = send.send_blocking(rt.block_on(fut)); + }); + DroppableFuture::new(async move { receive.recv().await.ok() }) + } + async fn delay(&self, duration: Duration) { + let (send, receive) = async_channel::bounded(1); + thread::spawn(move || { + thread::sleep(duration); + let _ = send.send_blocking(()); + }); + let _ = receive.recv().await; + } + + fn interval(&self, period: Duration) -> impl Stream + Send + 'static { + let (send, receive) = async_channel::bounded(1); + thread::spawn(move || loop { + thread::sleep(period); + if send.send_blocking(()).is_err() { + break; + } + }); + + receive + } + + fn block_on(&self, fut: Fut) -> Fut::Output { + future::block_on(fut) + } +} +impl From for Runtime { + fn from(value: RuntimelessRuntime) -> Self { + Runtime::new(value) + } +} +impl RuntimelessRuntime { + /// Spawn a future on the runtime, returning a future that has detach-on-drop semantics + /// + /// Spawned tasks conform to the following behavior: + /// + /// * detach on drop: If the returned [`DroppableFuture`] is dropped immediately, the task will + /// continue to execute until completion. + /// + /// * unwinding: If the spawned future panics, this must not propagate to the join + /// handle. Instead, the awaiting the join handle returns None in case of panic. + pub fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let rt = *self; + let (send, receive) = async_channel::bounded(1); + thread::spawn(move || { + let _ = send.send_blocking(rt.block_on(fut)); + }); + DroppableFuture::new(async move { receive.recv().await.ok() }) + } + + /// Wake in this amount of wall time + pub async fn delay(&self, duration: Duration) { + RuntimeTrait::delay(self, duration).await + } + + /// Returns a [`Stream`] that yields a `()` on the provided period + pub fn interval(&self, period: Duration) -> impl Stream + Send + 'static { + let (send, receive) = async_channel::bounded(1); + thread::spawn(move || loop { + thread::sleep(period); + if send.is_closed() { + break; + } + let _ = send.send_blocking(()); + }); + + receive + } + + /// Runtime implementation hook for blocking on a top level future. + pub fn block_on(&self, fut: Fut) -> Fut::Output { + future::block_on(fut) + } + + /// Race a future against the provided duration, returning None in case of timeout. + pub async fn timeout(&self, duration: Duration, fut: Fut) -> Option + where + Fut: Future + Send, + Fut::Output: Send + 'static, + { + RuntimeTrait::timeout(self, duration, fut).await + } +} diff --git a/testing/src/runtimeless/server.rs b/testing/src/runtimeless/server.rs new file mode 100644 index 0000000000..597b176182 --- /dev/null +++ b/testing/src/runtimeless/server.rs @@ -0,0 +1,84 @@ +use super::SERVERS; +use crate::{RuntimelessRuntime, TestTransport}; +use async_channel::Receiver; +use std::io::{Error, ErrorKind, Result}; +use trillium::Info; +use trillium_server_common::{Acceptor, Config, ConfigExt, Server}; +use url::Url; + +/// A [`Server`] for testing that does not depend on any runtime +#[derive(Debug)] +pub struct RuntimelessServer { + host: String, + port: u16, + channel: Receiver, +} + +impl RuntimelessServer { + /// returns whether there are any currently registered servers + pub fn is_empty() -> bool { + SERVERS.is_empty() + } + + /// returns the number of currently registered servers + pub fn len() -> usize { + SERVERS.len() + } +} + +impl Server for RuntimelessServer { + type Transport = TestTransport; + type Runtime = RuntimelessRuntime; + + const DESCRIPTION: &'static str = "test server"; + + fn runtime() -> Self::Runtime { + RuntimelessRuntime::default() + } + + async fn accept(&mut self) -> Result { + self.channel + .recv() + .await + .map_err(|e| Error::new(ErrorKind::Other, e.to_string())) + } + + fn build_listener(config: &Config) -> Self + where + A: Acceptor, + { + let mut port = config.port(); + let host = config.host(); + if port == 0 { + loop { + port = fastrand::u16(..); + if !SERVERS.contains_key(&(host.clone(), port)) { + break; + } + } + } + + let entry = SERVERS + .entry((host.clone(), port)) + .or_insert_with(async_channel::unbounded); + + let (_, channel) = entry.value(); + + Self { + host, + channel: channel.clone(), + port, + } + } + + async fn clean_up(self) { + SERVERS.remove(&(self.host, self.port)); + } + + fn info(&self) -> Info { + let mut info = Info::from(&*format!("{}:{}", &self.host, &self.port)); + info.state_mut() + .insert(Url::parse(&format!("http://{}:{}", &self.host, self.port)).unwrap()); + info + } +} diff --git a/testing/src/server_connector.rs b/testing/src/server_connector.rs index 916975d781..7b007b26f4 100644 --- a/testing/src/server_connector.rs +++ b/testing/src/server_connector.rs @@ -1,21 +1,23 @@ -use crate::TestTransport; -use std::sync::Arc; +use crate::{RuntimeType, TestTransport}; +use std::{io, sync::Arc}; +use trillium::Handler; +use trillium_http::Conn; +use trillium_server_common::Connector; use url::Url; /// a bridge between trillium servers and clients #[derive(Debug)] pub struct ServerConnector { handler: Arc, + runtime: RuntimeType, } -impl ServerConnector -where - H: trillium::Handler, -{ +impl ServerConnector { /// builds a new ServerConnector pub fn new(handler: H) -> Self { Self { handler: Arc::new(handler), + runtime: RuntimeType::default(), } } @@ -25,8 +27,8 @@ where let handler = Arc::clone(&self.handler); - crate::spawn(async move { - trillium_http::Conn::map(server_transport, Default::default(), |mut conn| { + self.runtime.spawn(async move { + Conn::map(server_transport, Default::default(), |mut conn| { let handler = Arc::clone(&handler); async move { conn.set_secure(secure); @@ -43,85 +45,20 @@ where } } -impl trillium_server_common::Connector for ServerConnector { +impl Connector for ServerConnector { type Transport = TestTransport; - async fn connect(&self, url: &Url) -> std::io::Result { + type Runtime = RuntimeType; + async fn connect(&self, url: &Url) -> io::Result { Ok(self.connect(url.scheme() == "https").await) } - fn spawn + Send + 'static>(&self, fut: Fut) { - crate::spawn(fut); - } - - async fn delay(&self, duration: std::time::Duration) { - crate::delay(duration).await + fn runtime(&self) -> Self::Runtime { + #[allow(clippy::clone_on_copy)] + self.runtime.clone() } } /// build a connector from this handler -pub fn connector(handler: impl trillium::Handler) -> impl trillium_server_common::Connector { +pub fn connector(handler: impl Handler) -> impl Connector { ServerConnector::new(handler) } - -#[cfg(test)] -mod test { - use crate::server_connector::ServerConnector; - use trillium_client::Client; - - #[test] - fn test() { - crate::block_on(async { - let client = Client::new(ServerConnector::new("test")); - let mut conn = client.get("https://example.com/test").await.unwrap(); - assert_eq!(conn.response_body().read_string().await.unwrap(), "test"); - }); - } - - #[test] - fn test_no_dns() { - crate::block_on(async { - let client = Client::new(ServerConnector::new("test")); - let mut conn = client - .get("https://not.a.real.tld.example/test") - .await - .unwrap(); - assert_eq!(conn.response_body().read_string().await.unwrap(), "test"); - }); - } - - #[test] - fn test_post() { - crate::block_on(async { - let client = Client::new(ServerConnector::new( - |mut conn: trillium::Conn| async move { - let body = conn.request_body_string().await.unwrap(); - let response = format!( - "{} {}://{}{} with body \"{}\"", - conn.method(), - if conn.is_secure() { "https" } else { "http" }, - conn.inner().host().unwrap_or_default(), - conn.path(), - body - ); - - conn.ok(response) - }, - )); - - let body = client - .post("https://example.com/test") - .with_body("some body") - .await - .unwrap() - .response_body() - .read_string() - .await - .unwrap(); - - assert_eq!( - body, - "POST https://example.com/test with body \"some body\"" - ); - }); - } -} diff --git a/testing/src/with_server.rs b/testing/src/with_server.rs index 6c898c0434..4b7234acdf 100644 --- a/testing/src/with_server.rs +++ b/testing/src/with_server.rs @@ -1,7 +1,8 @@ -use crate::{block_on, ServerConnector}; +use crate::{block_on, config, ServerConnector}; use std::{error::Error, future::Future}; use trillium::Handler; use trillium_http::transport::BoxedTransport; +use trillium_server_common::RuntimeTrait; use url::Url; /** @@ -20,14 +21,15 @@ where Fun: FnOnce(Url) -> Fut, Fut: Future>>, { - block_on(async move { - let port = portpicker::pick_unused_port().expect("could not pick a port"); - let url = format!("http://localhost:{port}").parse().unwrap(); - let handle = crate::config() - .with_host("localhost") - .with_port(port) - .spawn(handler); - handle.info().await; + let config = config().with_host("localhost").with_port(0); + let runtime = config.runtime(); + runtime.block_on(async move { + let handle = config.spawn(handler); + let info = handle.info().await; + let url = info.state().get::().cloned().unwrap_or_else(|| { + let port = info.tcp_socket_addr().map(|t| t.port()).unwrap_or(0); + format!("http://localhost:{port}").parse().unwrap() + }); tests(url).await.unwrap(); handle.shut_down().await; }); diff --git a/testing/tests/runtimeless.rs b/testing/tests/runtimeless.rs new file mode 100644 index 0000000000..b5310ad3d0 --- /dev/null +++ b/testing/tests/runtimeless.rs @@ -0,0 +1,36 @@ +use test_harness::test; +use trillium_client::Client; +use trillium_server_common::Config; +use trillium_testing::{harness, RuntimelessClientConfig, RuntimelessServer, TestResult}; +#[test(harness)] +async fn round_trip() -> TestResult { + let handle1 = Config::::new() + .with_host("host.com") + .with_port(80) + .spawn("server 1"); + handle1.info().await; + + let handle2 = Config::::new() + .with_host("other_host.com") + .with_port(80) + .spawn("server 2"); + handle2.info().await; + + let client = Client::new(RuntimelessClientConfig::default()); + let mut conn = client.get("http://host.com").await?; + assert_eq!(conn.response_body().await?, "server 1"); + + let mut conn = client.get("http://other_host.com").await?; + assert_eq!(conn.response_body().await?, "server 2"); + + handle1.shut_down().await; + assert!(client.get("http://host.com").await.is_err()); + assert!(client.get("http://other_host.com").await.is_ok()); + + handle2.shut_down().await; + assert!(client.get("http://other_host.com").await.is_err()); + + assert!(RuntimelessServer::is_empty()); + + Ok(()) +} diff --git a/testing/tests/server_connector.rs b/testing/tests/server_connector.rs new file mode 100644 index 0000000000..519b3cab38 --- /dev/null +++ b/testing/tests/server_connector.rs @@ -0,0 +1,54 @@ +use test_harness::test; +use trillium_client::Client; +use trillium_testing::{harness, ServerConnector}; + +#[test(harness)] +async fn test() { + let client = Client::new(ServerConnector::new("test")); + let mut conn = client.get("https://example.com/test").await.unwrap(); + assert_eq!(conn.response_body().read_string().await.unwrap(), "test"); +} + +#[test(harness)] +async fn test_no_dns() { + let client = Client::new(ServerConnector::new("test")); + let mut conn = client + .get("https://not.a.real.tld.example/test") + .await + .unwrap(); + assert_eq!(conn.response_body().read_string().await.unwrap(), "test"); +} + +#[test(harness)] +async fn test_post() { + let client = Client::new(ServerConnector::new( + |mut conn: trillium::Conn| async move { + let body = conn.request_body_string().await.unwrap(); + let response = format!( + "{} {}://{}{} with body \"{}\"", + conn.method(), + if conn.is_secure() { "https" } else { "http" }, + conn.inner().host().unwrap_or_default(), + conn.path(), + body + ); + + conn.ok(response) + }, + )); + + let body = client + .post("https://example.com/test") + .with_body("some body") + .await + .unwrap() + .response_body() + .read_string() + .await + .unwrap(); + + assert_eq!( + body, + "POST https://example.com/test with body \"some body\"" + ); +} diff --git a/testing/tests/spawn.rs b/testing/tests/spawn.rs index f1ce60e386..d4d5ac6a1d 100644 --- a/testing/tests/spawn.rs +++ b/testing/tests/spawn.rs @@ -1,32 +1,41 @@ +use std::time::Duration; use test_harness::test; +use trillium_testing::{harness, runtime, RuntimeTrait, TestResult}; -#[test(harness = trillium_testing::harness)] -async fn spawn_works() -> trillium_testing::TestResult { - let fut = trillium_testing::spawn(async { - std::thread::sleep(std::time::Duration::from_secs(1)); +#[test(harness)] +async fn spawn_works() -> TestResult { + let runtime = runtime(); + let rt = runtime.clone(); + let fut = rt.spawn(async move { + let runtime = runtime; + runtime.delay(Duration::from_secs(1)).await; 1 }); assert_eq!(1, fut.await.unwrap()); Ok(()) } -#[test(harness = trillium_testing::harness)] -async fn dropped_spawn_task_still_finishes() -> trillium_testing::TestResult { +#[test(harness)] +async fn dropped_spawn_task_still_finishes() -> TestResult { let (tx, rx) = async_channel::unbounded(); - drop(trillium_testing::spawn(async move { - std::thread::sleep(std::time::Duration::from_secs(1)); + + runtime().spawn(async move { + runtime().delay(Duration::from_secs(1)).await; tx.send(1).await.unwrap(); 2 - })); + }); assert_eq!(1, rx.recv().await.unwrap()); Ok(()) } -#[test(harness = trillium_testing::harness)] -async fn panic_in_spawn_returns_none() -> trillium_testing::TestResult { - let fut = trillium_testing::spawn(async move { - std::thread::sleep(std::time::Duration::from_secs(1)); +#[test(harness)] +async fn panic_in_spawn_returns_none() -> TestResult { + let runtime = runtime(); + let rt = runtime.clone(); + let fut = rt.spawn(async move { + let runtime = runtime; + runtime.delay(Duration::from_secs(1)).await; panic!(); }); diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index c41fd527fd..5774ad077b 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -22,7 +22,6 @@ trillium-server-common = { path = "../server-common", version = "0.5.2" } [dependencies.tokio] version = "1.35.1" features = ["rt", "net", "fs", "rt-multi-thread", "time"] -package = "tokio" [target.'cfg(unix)'.dependencies] signal-hook = "0.3.17" @@ -30,4 +29,4 @@ signal-hook-tokio = { version = "0.3.1", features = ["futures-v0_3"] } [dev-dependencies] env_logger = "0.11.0" -tokio = { version = "1.35.1", features = ["full"], package = "tokio" } +tokio = { version = "1.35.1", features = ["full"] } diff --git a/tokio/src/client.rs b/tokio/src/client.rs index dbfabe0c2b..f03736f914 100644 --- a/tokio/src/client.rs +++ b/tokio/src/client.rs @@ -1,7 +1,6 @@ -use crate::TokioTransport; +use crate::{TokioRuntime, TokioTransport}; use async_compat::Compat; use std::{ - future::Future, io::{Error, ErrorKind, Result}, time::Duration, }; @@ -59,6 +58,7 @@ impl ClientConfig { impl Connector for ClientConfig { type Transport = TokioTransport>; + type Runtime = TokioRuntime; async fn connect(&self, url: &Url) -> Result { if url.scheme() != "http" { @@ -98,11 +98,7 @@ impl Connector for ClientConfig { Ok(tcp) } - fn spawn + Send + 'static>(&self, fut: Fut) { - tokio::task::spawn(fut); - } - - async fn delay(&self, duration: Duration) { - tokio::time::sleep(duration).await + fn runtime(&self) -> Self::Runtime { + TokioRuntime::default() } } diff --git a/tokio/src/lib.rs b/tokio/src/lib.rs index 5fa79d71c2..500344a845 100644 --- a/tokio/src/lib.rs +++ b/tokio/src/lib.rs @@ -31,8 +31,6 @@ async fn main() { ``` */ -use std::future::Future; - use trillium::Handler; pub use trillium_server_common::{Binding, Swansong}; @@ -116,14 +114,5 @@ pub fn config() -> Config<()> { Config::new() } -/** -reexport tokio runtime block_on -*/ -pub fn block_on, T>(future: Fut) -> T { - tokio::runtime::Runtime::new().unwrap().block_on(future) -} - -/// spawn and detach a Future that returns () -pub fn spawn + Send + 'static>(future: Fut) { - tokio::task::spawn(future); -} +mod runtime; +pub use runtime::TokioRuntime; diff --git a/tokio/src/runtime.rs b/tokio/src/runtime.rs new file mode 100644 index 0000000000..5520d0ee96 --- /dev/null +++ b/tokio/src/runtime.rs @@ -0,0 +1,117 @@ +use std::{future::Future, sync::Arc, time::Duration}; +use tokio::{runtime::Handle, time}; +use tokio_stream::{wrappers::IntervalStream, Stream, StreamExt}; +use trillium_server_common::{DroppableFuture, Runtime, RuntimeTrait}; + +#[derive(Debug, Clone)] +enum Inner { + AlreadyRunning(Handle), + Owned(Arc), +} + +/// tokio runtime +#[derive(Clone, Debug)] +pub struct TokioRuntime(Inner); + +impl Default for TokioRuntime { + fn default() -> Self { + if let Ok(handle) = Handle::try_current() { + Self(Inner::AlreadyRunning(handle)) + } else { + Self(Inner::Owned(Arc::new( + tokio::runtime::Runtime::new().unwrap(), + ))) + } + } +} + +impl RuntimeTrait for TokioRuntime { + fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let join_handle = match &self.0 { + Inner::AlreadyRunning(handle) => handle.spawn(fut), + Inner::Owned(runtime) => runtime.spawn(fut), + }; + DroppableFuture::new(async move { join_handle.await.ok() }) + } + + async fn delay(&self, duration: Duration) { + time::sleep(duration).await; + } + + fn interval(&self, period: Duration) -> impl Stream + Send + 'static { + IntervalStream::new(time::interval(period)).map(|_| ()) + } + + fn block_on(&self, fut: Fut) -> Fut::Output { + match &self.0 { + Inner::AlreadyRunning(handle) => handle.block_on(fut), + Inner::Owned(runtime) => runtime.block_on(fut), + } + } +} + +impl TokioRuntime { + /// Spawn a future on the runtime, returning a future that has detach-on-drop semantics + /// + /// Spawned tasks conform to the following behavior: + /// + /// * detach on drop: If the returned [`DroppableFuture`] is dropped immediately, the task will + /// continue to execute until completion. + /// + /// * unwinding: If the spawned future panics, this must not propagate to the join + /// handle. Instead, the awaiting the join handle returns None in case of panic. + pub fn spawn( + &self, + fut: Fut, + ) -> DroppableFuture> + Send + 'static> + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let join_handle = match &self.0 { + Inner::AlreadyRunning(handle) => handle.spawn(fut), + Inner::Owned(runtime) => runtime.spawn(fut), + }; + DroppableFuture::new(async move { join_handle.await.ok() }) + } + + /// wake in this amount of wall time + pub async fn delay(&self, duration: Duration) { + time::sleep(duration).await; + } + + /// Returns a [`Stream`] that yields a `()` on the provided period + pub fn interval(&self, period: Duration) -> impl Stream + Send + 'static { + IntervalStream::new(time::interval(period)).map(|_| ()) + } + + /// Runtime implementation hook for blocking on a top level future. + pub fn block_on(&self, fut: Fut) -> Fut::Output { + match &self.0 { + Inner::AlreadyRunning(handle) => handle.block_on(fut), + Inner::Owned(runtime) => runtime.block_on(fut), + } + } + + /// Race a future against the provided duration, returning None in case of timeout. + pub async fn timeout(&self, duration: Duration, fut: Fut) -> Option + where + Fut: Future + Send, + Fut::Output: Send + 'static, + { + RuntimeTrait::timeout(self, duration, fut).await + } +} + +impl From for Runtime { + fn from(value: TokioRuntime) -> Self { + Runtime::new(value) + } +} diff --git a/tokio/src/server.rs b/tokio/src/server.rs index cc96a2a15d..7ea887c3c5 100644 --- a/tokio/src/server.rs +++ b/tokio/src/server.rs @@ -3,7 +3,6 @@ mod unix; #[cfg(unix)] pub use unix::TokioServer; -//#[cfg(not(unix))] mod tcp; #[cfg(not(unix))] pub use tcp::TokioServer; diff --git a/tokio/src/server/tcp.rs b/tokio/src/server/tcp.rs index 85b453ff76..ddc0201f5c 100644 --- a/tokio/src/server/tcp.rs +++ b/tokio/src/server/tcp.rs @@ -1,10 +1,7 @@ -use crate::TokioTransport; +use crate::{TokioRuntime, TokioTransport}; use async_compat::Compat; -use std::{future::Future, io::Result}; -use tokio::{ - net::{TcpListener, TcpStream}, - spawn, -}; +use std::{io, net}; +use tokio::net::{TcpListener, TcpStream}; use trillium::Info; use trillium_server_common::Server; @@ -19,6 +16,7 @@ impl From for TokioServer { } impl Server for TokioServer { + type Runtime = TokioRuntime; type Transport = TokioTransport>; const DESCRIPTION: &'static str = concat!( " (", @@ -28,14 +26,14 @@ impl Server for TokioServer { ")" ); - async fn accept(&mut self) -> Result { + async fn accept(&mut self) -> io::Result { self.0 .accept() .await .map(|(t, _)| TokioTransport(Compat::new(t))) } - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { + fn listener_from_tcp(tcp: net::TcpListener) -> Self { Self(tcp.try_into().unwrap()) } @@ -43,11 +41,7 @@ impl Server for TokioServer { self.0.local_addr().unwrap().into() } - fn spawn(fut: impl Future + Send + 'static) { - spawn(fut); - } - - fn block_on(fut: impl Future + 'static) { - crate::block_on(fut); + fn runtime() -> Self::Runtime { + TokioRuntime::default() } } diff --git a/tokio/src/server/unix.rs b/tokio/src/server/unix.rs index fa9fb5527c..674da9a4c3 100644 --- a/tokio/src/server/unix.rs +++ b/tokio/src/server/unix.rs @@ -1,10 +1,7 @@ -use crate::TokioTransport; +use crate::{TokioRuntime, TokioTransport}; use async_compat::Compat; -use std::{future::Future, io::Result}; -use tokio::{ - net::{TcpListener, TcpStream, UnixListener, UnixStream}, - spawn, -}; +use std::io::Result; +use tokio::net::{TcpListener, TcpStream, UnixListener, UnixStream}; use trillium::{log_error, Info}; use trillium_server_common::{ Binding::{self, *}, @@ -28,7 +25,9 @@ impl From for TokioServer { } impl Server for TokioServer { + type Runtime = TokioRuntime; type Transport = Binding>, TokioTransport>>; + const DESCRIPTION: &'static str = concat!( " (", env!("CARGO_PKG_NAME"), @@ -75,14 +74,6 @@ impl Server for TokioServer { } } - fn spawn(fut: impl Future + Send + 'static) { - spawn(fut); - } - - fn block_on(fut: impl Future + 'static) { - crate::block_on(fut) - } - fn listener_from_tcp(tcp: std::net::TcpListener) -> Self { Self(Tcp(tcp.try_into().unwrap())) } @@ -101,4 +92,8 @@ impl Server for TokioServer { } } } + + fn runtime() -> Self::Runtime { + TokioRuntime::default() + } } diff --git a/tokio/tests/tests.rs b/tokio/tests/tests.rs index 2d2a697b32..95861f5345 100644 --- a/tokio/tests/tests.rs +++ b/tokio/tests/tests.rs @@ -1,20 +1,30 @@ -#[cfg(feature = "tokio")] +use trillium::Swansong; +use trillium_tokio::config; + +#[tokio::test] +async fn spawn_async() { + config().with_port(0).spawn(()).shut_down().await; +} + #[test] -fn smoke() { - use trillium_testing::prelude::*; +fn spawn_block() { + config().with_port(0).spawn(()).shut_down().block(); +} - fn app() -> impl trillium::Handler { - |conn: trillium::Conn| async move { - let response = tokio::task::spawn(async { - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - "successfully spawned a task" - }) - .await - .unwrap(); - conn.ok(response) - } - } +#[test] +fn run() { + let swansong = Swansong::new(); + swansong.shut_down(); + config().with_port(0).with_swansong(swansong).run(()); +} - let app = app(); - assert_ok!(get("/").on(&app), "successfully spawned a task"); +#[tokio::test] +async fn run_async() { + let swansong = Swansong::new(); + swansong.shut_down(); + config() + .with_port(0) + .with_swansong(swansong) + .run_async(()) + .await; } diff --git a/trillium/tests/liveness.rs b/trillium/tests/liveness.rs index d4a3060983..d6fa2b372e 100644 --- a/trillium/tests/liveness.rs +++ b/trillium/tests/liveness.rs @@ -4,16 +4,15 @@ use std::{ io, pin::Pin, task::{Context, Poll}, - thread, time::Duration, }; use test_harness::test; use trillium::Conn; -use trillium_testing::{config, harness, ArcedConnector, ClientConfig, Connector, TestResult}; +use trillium_testing::{client_config, config, harness, ArcedConnector, Connector, TestResult}; #[test(harness)] async fn infinitely_pending_task() -> TestResult { - let connector = ArcedConnector::new(ClientConfig::default()); + let connector = ArcedConnector::new(client_config()); let handle = config() .with_host("localhost") @@ -50,7 +49,7 @@ async fn infinitely_pending_task() -> TestResult { #[test(harness)] async fn is_disconnected() -> TestResult { let _ = env_logger::builder().is_test(true).try_init(); - let connector = ArcedConnector::new(ClientConfig::default()); + let connector = ArcedConnector::new(client_config()); let (delay_sender, delay_receiver) = async_channel::unbounded(); let (disconnected_sender, disconnected_receiver) = async_channel::unbounded(); let handle = config() @@ -70,6 +69,7 @@ async fn is_disconnected() -> TestResult { }); let info = handle.info().await; + let runtime = handle.runtime(); let url = format!("http://{}", info.listener_description()) .parse() @@ -93,7 +93,7 @@ async fn is_disconnected() -> TestResult { .write_all(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n") .await?; drop(client); - thread::sleep(Duration::from_millis(10)); + runtime.delay(Duration::from_millis(10)).await; delay_sender.send(()).await?; assert!(disconnected_receiver.recv().await?);