Skip to content

Commit

Permalink
handle various usage of unwrap to proper error handling (#880)
Browse files Browse the repository at this point in the history
For #9
  • Loading branch information
howardjohn authored Apr 7, 2024
1 parent dd08055 commit 961eb56
Show file tree
Hide file tree
Showing 11 changed files with 112 additions and 94 deletions.
89 changes: 38 additions & 51 deletions src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use hyper::{header::HeaderValue, header::CONTENT_TYPE, Request, Response};
use pprof::protos::Message;
use std::borrow::Borrow;
use std::collections::HashMap;

use std::str::FromStr;
use std::sync::Arc;
use std::time::SystemTime;
Expand Down Expand Up @@ -122,25 +123,24 @@ impl Service {
pub fn spawn(self) {
self.s.spawn(|state, req| async move {
match req.uri().path() {
"/debug/pprof/profile" => Ok(handle_pprof(req).await),
"/debug/pprof/heap" => Ok(handle_jemalloc_pprof_heapgen(req).await),
"/debug/pprof/profile" => handle_pprof(req).await,
"/debug/pprof/heap" => handle_jemalloc_pprof_heapgen(req).await,
"/quitquitquit" => Ok(handle_server_shutdown(
state.shutdown_trigger.clone(),
req,
state.config.self_termination_deadline,
)
.await),
"/config_dump" => Ok(handle_config_dump(
ConfigDump {
"/config_dump" => {
handle_config_dump(ConfigDump {
proxy_state: state.proxy_state.clone(),
static_config: Default::default(),
version: BuildInfo::new(),
config: state.config.clone(),
certificates: dump_certs(state.cert_manager.borrow()).await,
},
// req, // bring this back if we start using it
)
.await),
})
.await
}
"/logging" => Ok(handle_logging(req).await),
"/" => Ok(handle_dashboard(req, &state.handlers).await),
_ => match Self::find_handler(state.as_ref(), req.uri().path()) {
Expand Down Expand Up @@ -245,30 +245,22 @@ async fn dump_certs(cert_manager: &SecretManager) -> Vec<CertsDump> {
dump
}

async fn handle_pprof(_req: Request<Incoming>) -> Response<Full<Bytes>> {
async fn handle_pprof(_req: Request<Incoming>) -> anyhow::Result<Response<Full<Bytes>>> {
let guard = pprof::ProfilerGuardBuilder::default()
.frequency(1000)
// .blocklist(&["libc", "libgcc", "pthread", "vdso"])
.build()
.unwrap();
.build()?;

tokio::time::sleep(Duration::from_secs(10)).await;
match guard.report().build() {
Ok(report) => {
let profile = report.pprof().unwrap();
let report = guard.report().build()?;
let profile = report.pprof()?;

let body = profile.write_to_bytes().unwrap();
let body = profile.write_to_bytes()?;

Response::builder()
.status(hyper::StatusCode::OK)
.body(body.into())
.unwrap()
}
Err(err) => plaintext_response(
hyper::StatusCode::INTERNAL_SERVER_ERROR,
format!("failed to build profile: {err}\n"),
),
}
Ok(Response::builder()
.status(hyper::StatusCode::OK)
.body(body.into())
.expect("builder with known status code should not fail"))
}

async fn handle_server_shutdown(
Expand All @@ -291,10 +283,7 @@ async fn handle_server_shutdown(
}
}

async fn handle_config_dump(
mut dump: ConfigDump,
// _req: Request<Incoming>,
) -> Response<Full<Bytes>> {
async fn handle_config_dump(mut dump: ConfigDump) -> anyhow::Result<Response<Full<Bytes>>> {
if let Some(cfg) = dump.config.local_xds_config.clone() {
match cfg.read_to_string().await {
Ok(data) => match serde_yaml::from_str(&data) {
Expand All @@ -311,11 +300,11 @@ async fn handle_config_dump(
}
}

let body = serde_json::to_string_pretty(&dump).unwrap();
Response::builder()
let body = serde_json::to_string_pretty(&dump)?;
Ok(Response::builder()
.status(hyper::StatusCode::OK)
.body(body.into())
.unwrap()
.expect("builder with known status code should not fail"))
}

//mirror envoy's behavior: https://www.envoyproxy.io/docs/envoy/latest/operations/admin#post--logging
Expand Down Expand Up @@ -393,33 +382,31 @@ fn change_log_level(reset: bool, level: &str) -> Response<Full<Bytes>> {
}

#[cfg(feature = "jemalloc")]
async fn handle_jemalloc_pprof_heapgen(_req: Request<Incoming>) -> Response<Full<Bytes>> {
let mut prof_ctl = jemalloc_pprof::PROF_CTL.as_ref().unwrap().lock().await;
async fn handle_jemalloc_pprof_heapgen(
_req: Request<Incoming>,
) -> anyhow::Result<Response<Full<Bytes>>> {
let mut prof_ctl = jemalloc_pprof::PROF_CTL.as_ref()?.lock().await;
if !prof_ctl.activated() {
Response::builder()
return Ok(Response::builder()
.status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
.body("jemalloc not enabled".into())
.unwrap()
} else {
let pprof = prof_ctl.dump_pprof().map_err(|err| {
Response::builder()
.status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
.body(err)
.unwrap()
});
Response::builder()
.status(hyper::StatusCode::OK)
.body(Bytes::from(pprof.unwrap()).into())
.unwrap()
.expect("builder with known status code should not fail"));
}
let pprof = prof_ctl.dump_pprof()?;
Ok(Response::builder()
.status(hyper::StatusCode::OK)
.body(Bytes::from(pprof?).into())
.expect("builder with known status code should not fail"))
}

#[cfg(not(feature = "jemalloc"))]
async fn handle_jemalloc_pprof_heapgen(_req: Request<Incoming>) -> Response<Full<Bytes>> {
Response::builder()
async fn handle_jemalloc_pprof_heapgen(
_req: Request<Incoming>,
) -> anyhow::Result<Response<Full<Bytes>>> {
Ok(Response::builder()
.status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
.body("jemalloc not enabled".into())
.unwrap()
.expect("builder with known status code should not fail"))
}

fn base64_encode(data: String) -> String {
Expand Down Expand Up @@ -732,7 +719,7 @@ mod tests {
//
// this could happen for a variety of reasons; for example some types
// may need custom serialize/deserialize to be keys in a map, like NetworkAddress
let resp = handle_config_dump(dump).await;
let resp = handle_config_dump(dump).await.unwrap();

let resp_bytes = resp
.body()
Expand Down
33 changes: 21 additions & 12 deletions src/hyper_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use std::task::{Context, Poll};
Expand All @@ -23,6 +24,7 @@ use std::{

use bytes::Bytes;
use drain::Watch;
use futures_util::TryFutureExt;
use http_body_util::Full;
use hyper::client;
use hyper::rt::Sleep;
Expand Down Expand Up @@ -205,7 +207,7 @@ impl<S> Server<S> {
where
S: Send + Sync + 'static,
F: Fn(Arc<S>, Request<hyper::body::Incoming>) -> R + Send + Sync + 'static,
R: Future<Output = Result<Response<Full<Bytes>>, hyper::Error>> + Send + Sync + 'static,
R: Future<Output = Result<Response<Full<Bytes>>, anyhow::Error>> + Send + Sync + 'static,
{
use futures_util::StreamExt as OtherStreamExt;
let address = self.address();
Expand All @@ -229,18 +231,25 @@ impl<S> Server<S> {
let f = f.clone();
let state = state.clone();
tokio::spawn(async move {
let serve = http1_server()
.half_close(true)
.header_read_timeout(Duration::from_secs(2))
.max_buf_size(8 * 1024)
.serve_connection(
hyper_util::rt::TokioIo::new(socket),
hyper::service::service_fn(move |req| {
let state = state.clone();
let serve =
http1_server()
.half_close(true)
.header_read_timeout(Duration::from_secs(2))
.max_buf_size(8 * 1024)
.serve_connection(
hyper_util::rt::TokioIo::new(socket),
hyper::service::service_fn(move |req| {
let state = state.clone();

f(state, req)
}),
);
// Failures would abort the whole connection; we just want to return an HTTP error
f(state, req).or_else(|err| async move {
Ok::<Response<Full<Bytes>>, Infallible>(Response::builder()
.status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
.body(err.to_string().into())
.expect("builder with known status code should not fail"))
})
}),
);
// Wait for drain to signal or connection serving to complete
match futures_util::future::select(Box::pin(drain.signaled()), serve).await {
// We got a shutdown request. Start gracful shutdown and wait for the pending requests to complete.
Expand Down
4 changes: 2 additions & 2 deletions src/identity/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ impl Worker {
},
// Initiate the next fetch.
true = maybe_sleep_until(next), if fetches.len() < self.concurrency as usize => {
let (id, _) = pending.pop().unwrap();
let (id, _) = pending.pop().expect("pending should always have an element at this point");
processing.insert(id.to_owned(), Fetch::Processing);
fetches.push(async move {
let res = self.client.fetch_certificate(&id).await;
Expand Down Expand Up @@ -456,7 +456,7 @@ impl fmt::Debug for SecretManager {
impl SecretManager {
pub async fn new(cfg: crate::config::Config) -> Result<Self, Error> {
let caclient = CaClient::new(
cfg.ca_address.unwrap(),
cfg.ca_address.expect("ca_address must be set to use CA"),
Box::new(tls::ControlPlaneAuthentication::RootCert(
cfg.ca_root_cert.clone(),
)),
Expand Down
11 changes: 8 additions & 3 deletions src/metrics/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ async fn handle_metrics(
_req: Request<Incoming>,
) -> Response<Full<Bytes>> {
let mut buf = String::new();
let reg = reg.lock().unwrap();
encode(&mut buf, &reg).unwrap();
let reg = reg.lock().expect("mutex");
if let Err(err) = encode(&mut buf, &reg) {
return Response::builder()
.status(hyper::StatusCode::INTERNAL_SERVER_ERROR)
.body(err.to_string().into())
.expect("builder with known status code should not fail");
}

Response::builder()
.status(hyper::StatusCode::OK)
Expand All @@ -71,5 +76,5 @@ async fn handle_metrics(
"application/openmetrics-text;charset=utf-8;version=1.0.0",
)
.body(buf.into())
.unwrap()
.expect("builder with known status code should not fail")
}
18 changes: 9 additions & 9 deletions src/proxy/inbound.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ impl Inbound {
// Override with our explicitly configured setting
pi.cfg.enable_original_source = Some(transparent);
info!(
address=%listener.local_addr().unwrap(),
address=%listener.local_addr().expect("local_addr available"),
component="inbound",
transparent,
"listener established",
Expand All @@ -79,7 +79,7 @@ impl Inbound {
}

pub(super) fn address(&self) -> SocketAddr {
self.listener.local_addr().unwrap()
self.listener.local_addr().expect("local_addr available")
}

pub(super) async fn run(self) {
Expand All @@ -97,7 +97,7 @@ impl Inbound {
let (raw_socket, ssl) = tls.get_ref();
let src_identity: Option<Identity> = tls::identity_from_connection(ssl);
let dst = crate::socket::orig_dst_addr_or_default(raw_socket);
let src = to_canonical(raw_socket.peer_addr().unwrap());
let src = to_canonical(raw_socket.peer_addr().expect("peer_addr available"));
let pi = self.pi.clone();
let connection_manager = self.pi.connection_manager.clone();
let drain = sub_drain.clone();
Expand Down Expand Up @@ -332,7 +332,7 @@ impl Inbound {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Empty::new())
.unwrap());
.expect("builder with known status code"));
}
};

Expand All @@ -345,7 +345,7 @@ impl Inbound {
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Empty::new())
.unwrap());
.expect("builder with known status code"));
}
};

Expand Down Expand Up @@ -388,7 +388,7 @@ impl Inbound {
return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Empty::new())
.unwrap());
.expect("builder with known status code should not fail"));
}
// This check should be removed in favor of an L4 policy check
// We should express as policy whether or not traffic is allowed to bypass a waypoint
Expand All @@ -398,7 +398,7 @@ impl Inbound {
return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(Empty::new())
.unwrap());
.expect("builder with known status code should not fail"));
}
let source_ip = if from_waypoint {
// If the request is from our waypoint, trust the Forwarded header.
Expand Down Expand Up @@ -474,15 +474,15 @@ impl Inbound {
Ok(Response::builder()
.status(status_code)
.body(Empty::new())
.unwrap())
.expect("builder with known status code should not fail"))
}
// Return the 404 Not Found for other routes.
method => {
info!("Sending 404, got {method}");
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(Empty::new())
.unwrap())
.expect("builder with known status code should not fail"))
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/proxy/inbound_passthrough.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl InboundPassthrough {
pi.cfg.enable_original_source = Some(transparent);

info!(
address=%listener.local_addr().unwrap(),
address=%listener.local_addr().expect("local_addr available"),
component="inbound plaintext",
transparent,
"listener established",
Expand Down
Loading

0 comments on commit 961eb56

Please sign in to comment.