From 0c984a0878319cf5f9ab8fc0ee0f168b29331ec1 Mon Sep 17 00:00:00 2001 From: brandonpike Date: Mon, 15 Jul 2024 11:41:49 -0500 Subject: [PATCH] Add try_write to TcpStream (#176) Similar to the existing write behavior these never poll, which can cause busy loops. Future changes will address this, but having the API gap filled is worthwhile at this time. --- src/net/tcp/stream.rs | 49 ++++++++++++++++++++++++++++++++-------- tests/async_send_sync.rs | 5 +--- tests/tcp.rs | 24 ++++++++++++++++++++ 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/src/net/tcp/stream.rs b/src/net/tcp/stream.rs index 728a094..b60748b 100644 --- a/src/net/tcp/stream.rs +++ b/src/net/tcp/stream.rs @@ -93,6 +93,23 @@ impl TcpStream { Ok(TcpStream::new(pair, rx)) } + /// Try to write a buffer to the stream, returning how many bytes were + /// written. + /// + /// The function will attempt to write the entire contents of `buf`, but + /// only part of the buffer may be written. + /// + /// This function is usually paired with `writable()`. + /// + /// # Return + /// + /// If data is successfully written, `Ok(n)` is returned, where `n` is the + /// number of bytes written. If the stream is not ready to write data, + /// `Err(io::ErrorKind::WouldBlock)` is returned. + pub fn try_write(&self, buf: &[u8]) -> Result { + self.write_half.try_write(buf) + } + /// Returns the local address that this stream is bound to. pub fn local_addr(&self) -> Result { Ok(self.read_half.pair.local) @@ -110,6 +127,21 @@ impl TcpStream { } } + /// Waits for the socket to become writable. + /// + /// This function is equivalent to `ready(Interest::WRITABLE)` and is usually + /// paired with `try_write()`. + /// + /// # Cancel safety + /// + /// This method is cancel safe. Once a readiness event occurs, the method + /// will continue to return immediately until the readiness event is + /// consumed by an attempt to write that fails with `WouldBlock` or + /// `Poll::Pending`. + pub async fn writable(&self) -> Result<()> { + Ok(()) + } + /// Splits a `TcpStream` into a read half and a write half, which can be used /// to read and write the stream concurrently. /// @@ -220,19 +252,16 @@ pub(crate) struct WriteHalf { } impl WriteHalf { - fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + fn try_write(&self, buf: &[u8]) -> Result { if buf.remaining() == 0 { - return Poll::Ready(Ok(0)); + return Ok(0); } if self.is_shutdown { - return Poll::Ready(Err(io::Error::new( - io::ErrorKind::BrokenPipe, - "Broken pipe", - ))); + return Err(io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe")); } - let res = World::current(|world| { + World::current(|world| { let bytes = Bytes::copy_from_slice(buf); let len = bytes.len(); @@ -240,9 +269,11 @@ impl WriteHalf { self.send(world, Segment::Data(seq, bytes))?; Ok(len) - }); + }) + } - Poll::Ready(res) + fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll> { + Poll::Ready(self.try_write(buf)) } fn poll_shutdown_priv(&mut self) -> Poll> { diff --git a/tests/async_send_sync.rs b/tests/async_send_sync.rs index d3a11e1..c37259f 100644 --- a/tests/async_send_sync.rs +++ b/tests/async_send_sync.rs @@ -1,14 +1,11 @@ //! Copied over from: //! https://github.com/tokio-rs/tokio/blob/master/tokio/tests/async_send_sync.rs +#![allow(dead_code)] -#[allow(dead_code)] fn require_send(_t: &T) {} -#[allow(dead_code)] fn require_sync(_t: &T) {} -#[allow(dead_code)] fn require_unpin(_t: &T) {} -#[allow(dead_code)] struct Invalid; trait AmbiguousIfSend { diff --git a/tests/tcp.rs b/tests/tcp.rs index d96cc8c..42f8c5b 100644 --- a/tests/tcp.rs +++ b/tests/tcp.rs @@ -1162,3 +1162,27 @@ fn exhaust_ephemeral_ports() { _ = sim.run() } + +#[test] +fn try_write() -> Result { + let mut sim = Builder::new().build(); + sim.client("client", async move { + let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 1234)).await?; + + tokio::spawn(async move { + let (socket, _) = listener.accept().await.unwrap(); + + let written = socket.try_write(b"hello!").unwrap(); + assert_eq!(written, 6); + }); + + let mut socket = TcpStream::connect((Ipv4Addr::LOCALHOST, 1234)).await?; + let mut buf: [u8; 6] = [0; 6]; + socket.read_exact(&mut buf).await?; + assert_eq!(&buf, b"hello!"); + + Ok(()) + }); + + sim.run() +}