Skip to content

Commit

Permalink
copy: fix handling of blocked writers (istio#1250)
Browse files Browse the repository at this point in the history
Fixes istio#1196

This adds a test which tests how the copy logic behaves in the face of
writers that do not always accept all the data we are writing. This
exposes two issues:
* If the write is Pending entirely, we dropped the saved buffer.
* If we did a partial write, our check was backwards causing us to not
  save the buffer.

Co-authored-by: John Howard <[email protected]>
  • Loading branch information
istio-testing and howardjohn authored Aug 2, 2024
1 parent 25efb39 commit 3f6e6cb
Showing 1 changed file with 97 additions and 3 deletions.
100 changes: 97 additions & 3 deletions src/copy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -287,11 +287,17 @@ where

// This is just a reference counter. Hold onto it in case the write() is not complete.
let mut our_copy = buffer.clone();
let i = ready!(Pin::new(&mut *me.writer).poll_write_buf(cx, buffer))?;
let i = match Pin::new(&mut *me.writer).poll_write_buf(cx, buffer) {
Poll::Ready(written) => written?,
Poll::Pending => {
me.buf = Some(our_copy);
return Poll::Pending;
}
};
if i == 0 {
return Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()));
}
if our_copy.len() < i {
if i < our_copy.len() {
// We only partially consumed it; store it back for a future call, skipping the number of bytes we did read.
our_copy.advance(i);
me.buf = Some(our_copy);
Expand Down Expand Up @@ -432,8 +438,9 @@ where
mod tests {
use super::*;
use crate::test_helpers::helpers::initialize_telemetry;
use tokio::io::AsyncReadExt;
use rand::Rng;
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncReadExt, ReadBuf};

#[tokio::test]
async fn copy() {
Expand Down Expand Up @@ -479,4 +486,91 @@ mod tests {
assert_eq!(res.as_slice(), body);
}
}

#[tokio::test]
async fn copystress() {
initialize_telemetry();
let (mut client, ztunnel_downsteam) = tokio::io::duplex(32000);
let (mut server, ztunnel_upsteam) = tokio::io::duplex(32000);

// Spawn copy
tokio::task::spawn(async move {
let mut registry = prometheus_client::registry::Registry::default();
let metrics = std::sync::Arc::new(crate::proxy::Metrics::new(
crate::metrics::sub_registry(&mut registry),
));
let source_addr = "127.0.0.1:12345".parse().unwrap();
let dest_addr = "127.0.0.1:34567".parse().unwrap();
let cr = ConnectionResult::new(
source_addr,
dest_addr,
None,
std::time::Instant::now(),
crate::proxy::metrics::ConnectionOpen {
reporter: crate::proxy::Reporter::destination,
source: None,
derived_source: None,
destination: None,
connection_security_policy: crate::proxy::metrics::SecurityPolicy::unknown,
destination_service: None,
},
metrics.clone(),
);
copy_bidirectional(WeirdIO(ztunnel_downsteam), WeirdIO(ztunnel_upsteam), &cr).await
});
const WRITES: usize = 2560;
// Do a bunch of writes of various size, and expect the other end to receive them
let writer = tokio::task::spawn(async move {
for d in 0..WRITES {
let body: Vec<u8> = (0..d).map(|v| (v % 255) as u8).collect();
client.write_all(&body).await.unwrap();
}
});
let reader = tokio::task::spawn(async move {
for d in 0..WRITES {
let want: Vec<u8> = (0..d).map(|v| (v % 255) as u8).collect();
let mut got = vec![0; d];
server.read_exact(&mut got).await.unwrap();
assert_eq!(got.as_slice(), want);
}
});
tokio::try_join!(reader, writer).unwrap();
}

struct WeirdIO<I>(I);
impl<I: AsyncWrite + std::marker::Unpin> AsyncWrite for WeirdIO<I> {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, Error>> {
if buf.is_empty() {
return Poll::Ready(Ok(0));
}
let mut rng = rand::thread_rng();
let end = rng.gen_range(1..=buf.len()); // Ensure at least 1 byte is written
Pin::new(&mut self.0).poll_write(cx, &buf[0..end])
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Result<(), Error>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}
}
impl<I: AsyncRead + std::marker::Unpin> AsyncRead for WeirdIO<I> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<std::io::Result<()>> {
// TODO
Pin::new(&mut self.0).poll_read(cx, buf)
}
}
}

0 comments on commit 3f6e6cb

Please sign in to comment.