Skip to content

Commit

Permalink
Add try_write to TcpStream (#176)
Browse files Browse the repository at this point in the history
Similar to the existing write behavior these never poll, which can cause busy loops. Future changes will address this, but having the API gap filled is worthwhile at this time.
  • Loading branch information
brandonpike authored Jul 15, 2024
1 parent 766108f commit 0c984a0
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 13 deletions.
49 changes: 40 additions & 9 deletions src/net/tcp/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,23 @@ impl TcpStream {
Ok(TcpStream::new(pair, rx))
}

/// Try to write a buffer to the stream, returning how many bytes were
/// written.
///
/// The function will attempt to write the entire contents of `buf`, but
/// only part of the buffer may be written.
///
/// This function is usually paired with `writable()`.
///
/// # Return
///
/// If data is successfully written, `Ok(n)` is returned, where `n` is the
/// number of bytes written. If the stream is not ready to write data,
/// `Err(io::ErrorKind::WouldBlock)` is returned.
pub fn try_write(&self, buf: &[u8]) -> Result<usize> {
self.write_half.try_write(buf)
}

/// Returns the local address that this stream is bound to.
pub fn local_addr(&self) -> Result<SocketAddr> {
Ok(self.read_half.pair.local)
Expand All @@ -110,6 +127,21 @@ impl TcpStream {
}
}

/// Waits for the socket to become writable.
///
/// This function is equivalent to `ready(Interest::WRITABLE)` and is usually
/// paired with `try_write()`.
///
/// # Cancel safety
///
/// This method is cancel safe. Once a readiness event occurs, the method
/// will continue to return immediately until the readiness event is
/// consumed by an attempt to write that fails with `WouldBlock` or
/// `Poll::Pending`.
pub async fn writable(&self) -> Result<()> {
Ok(())
}

/// Splits a `TcpStream` into a read half and a write half, which can be used
/// to read and write the stream concurrently.
///
Expand Down Expand Up @@ -220,29 +252,28 @@ pub(crate) struct WriteHalf {
}

impl WriteHalf {
fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
fn try_write(&self, buf: &[u8]) -> Result<usize> {
if buf.remaining() == 0 {
return Poll::Ready(Ok(0));
return Ok(0);
}

if self.is_shutdown {
return Poll::Ready(Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"Broken pipe",
)));
return Err(io::Error::new(io::ErrorKind::BrokenPipe, "Broken pipe"));
}

let res = World::current(|world| {
World::current(|world| {
let bytes = Bytes::copy_from_slice(buf);
let len = bytes.len();

let seq = self.seq(world)?;
self.send(world, Segment::Data(seq, bytes))?;

Ok(len)
});
})
}

Poll::Ready(res)
fn poll_write_priv(&self, _cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
Poll::Ready(self.try_write(buf))
}

fn poll_shutdown_priv(&mut self) -> Poll<Result<()>> {
Expand Down
5 changes: 1 addition & 4 deletions tests/async_send_sync.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
//! Copied over from:
//! https://github.com/tokio-rs/tokio/blob/master/tokio/tests/async_send_sync.rs
#![allow(dead_code)]

#[allow(dead_code)]
fn require_send<T: Send>(_t: &T) {}
#[allow(dead_code)]
fn require_sync<T: Sync>(_t: &T) {}
#[allow(dead_code)]
fn require_unpin<T: Unpin>(_t: &T) {}

#[allow(dead_code)]
struct Invalid;

trait AmbiguousIfSend<A> {
Expand Down
24 changes: 24 additions & 0 deletions tests/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1162,3 +1162,27 @@ fn exhaust_ephemeral_ports() {

_ = sim.run()
}

#[test]
fn try_write() -> Result {
let mut sim = Builder::new().build();
sim.client("client", async move {
let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 1234)).await?;

tokio::spawn(async move {
let (socket, _) = listener.accept().await.unwrap();

let written = socket.try_write(b"hello!").unwrap();
assert_eq!(written, 6);
});

let mut socket = TcpStream::connect((Ipv4Addr::LOCALHOST, 1234)).await?;
let mut buf: [u8; 6] = [0; 6];
socket.read_exact(&mut buf).await?;
assert_eq!(&buf, b"hello!");

Ok(())
});

sim.run()
}

0 comments on commit 0c984a0

Please sign in to comment.