From 98b69ff39c33bf3ca889c81288df0c99c94ace7f Mon Sep 17 00:00:00 2001 From: Matt Gathu Date: Fri, 29 Dec 2023 11:14:15 +0100 Subject: [PATCH] Simplify byte ranges --- src/errors.rs | 5 ++ src/http.rs | 60 +++++++++++------------- src/server.rs | 125 +++++++++++++++++++++++++++++++------------------- 3 files changed, 108 insertions(+), 82 deletions(-) diff --git a/src/errors.rs b/src/errors.rs index 68da07b..78d42fe 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -3,6 +3,8 @@ use std::{io, num::ParseIntError, string::FromUtf8Error, time::SystemTimeError}; use thiserror::Error; +use crate::http::HeaderName; + pub type Result = std::result::Result; #[derive(Error, Debug)] @@ -26,6 +28,8 @@ pub enum SevaError { TestClient(String), #[error("URI Too Long")] UriTooLong, + #[error("Missing value for header: {0}")] + MissingHeaderValue(HeaderName), } #[derive(Error, Debug)] @@ -37,6 +41,7 @@ pub enum ParsingError { PestRuleError(String), DateTime(String), IntError(#[from] ParseIntError), + InvalidRangeHeader(String), } impl fmt::Display for ParsingError { diff --git a/src/http.rs b/src/http.rs index 353fede..1b82f01 100644 --- a/src/http.rs +++ b/src/http.rs @@ -316,7 +316,7 @@ len = { ASCII_DIGIT* } pub struct HttpParser; impl HttpParser { - pub fn parse_bytes_range(val: &str) -> Result> { + pub fn parse_bytes_range(val: &str, max_len: usize) -> Result> { let br = HttpParser::parse(Rule::bytes_range, val) .map_err(|e| ParsingError::PestRuleError(format!("{e:?}")))? .next() @@ -326,61 +326,53 @@ impl HttpParser { match pair.as_rule() { Rule::int_range => { let mut inner = pair.into_inner(); - let first_pos = inner + let start = inner .next() .unwrap() .as_str() .parse() .map_err(ParsingError::IntError)?; - let last_pos = match inner.next() { - Some(r) => Some( - r.as_str().parse().map_err(ParsingError::IntError)?, - ), - None => None, + let end = match inner.next() { + Some(r) => { + r.as_str().parse().map_err(ParsingError::IntError)? + } + None => max_len, }; - ranges.push(BytesRange::Int { - start: first_pos, - end: last_pos, - }); + if start > end { + Err(ParsingError::InvalidRangeHeader(val.to_owned()))?; + } + let size = end - start; + ranges.push(BytesRange { start, size }); } Rule::suffix_range => { let mut inner = pair.into_inner(); - let len = inner + let size = inner .next() .unwrap() .as_str() .parse() .map_err(ParsingError::IntError)?; - ranges.push(BytesRange::Suffix { len }); + if size >= max_len { + Err(ParsingError::InvalidRangeHeader(val.to_owned()))?; + } + let start = max_len - size; + ranges.push(BytesRange { start, size }); } _ => {} } } + if ranges.len() > 10 { + return Err(ParsingError::InvalidRangeHeader(val.to_owned()))?; + } Ok(ranges) } } #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -pub enum BytesRange { - Int { start: usize, end: Option }, - Suffix { len: usize }, -} - -impl BytesRange { - pub fn is_valid(&self, max_len: usize) -> bool { - match *self { - BytesRange::Int { start, end } => { - let end = end.unwrap_or(max_len); - if start > end { - false - } else { - end <= max_len - } - } - BytesRange::Suffix { len } => len > max_len, - } - } +pub struct BytesRange { + pub start: usize, + pub size: usize, } macro_rules! status_codes { @@ -695,8 +687,8 @@ mod tests { "bytes=500-600,601-999", "bytes=500-700,601-999", ] { - let range = HttpParser::parse_bytes_range(val); - assert!(range.is_ok(), "failed to parse: {val}"); + let range = HttpParser::parse_bytes_range(val, 10000); + assert!(range.is_ok(), "failed to parse: {val}. Reason: {range:?}"); } Ok(()) } diff --git a/src/server.rs b/src/server.rs index eaf0456..2f40e1e 100644 --- a/src/server.rs +++ b/src/server.rs @@ -14,11 +14,12 @@ use std::{ use bytes::{BufMut, BytesMut}; use chrono::Local; use clap::crate_version; +use contracts::debug_requires; use handlebars::Handlebars; use tracing::{debug, error, info, trace, warn}; use crate::{ - errors::{IoErrorUtils, Result, SevaError}, + errors::{IoErrorUtils, ParsingError, Result, SevaError}, fs::{DirEntry, EntryType}, http::{ HeaderName, HttpMethod, HttpParser, Request, Response, ResponseBuilder, @@ -146,6 +147,12 @@ impl RequestHandler { SevaError::UriTooLong => { self.send_error(StatusCode::UriTooLong, None)? } + SevaError::ParsingError(ParsingError::InvalidRangeHeader(hdr)) => { + self.send_error( + StatusCode::RangeNotSatisifiable, + Some(&format!("invalid range: {hdr}")), + )? + } _ => { error!("internal server error: {e}"); self.send_error( @@ -243,58 +250,46 @@ impl RequestHandler { Ok(()) } + #[debug_requires(request.headers.contains_key(&HeaderName::Range))] fn send_partial(&mut self, request: &Request, entry: &DirEntry) -> Result<()> { trace!("RequestHandler::send_partial"); - if let Some(val) = request.headers.get(&HeaderName::Range) { - let ranges = HttpParser::parse_bytes_range(val)?; - // we only serve the first range - if let Some(range) = ranges.into_iter().next() { - if !range.is_valid(entry.size as usize) { - return self.send_error( - StatusCode::RangeNotSatisifiable, - Some("invalid bytes range"), - ); - } - let (start, size) = match range { - crate::http::BytesRange::Int { start, end } => { - (start, end.unwrap_or(entry.size as usize) - start) - } - crate::http::BytesRange::Suffix { len } => { - let start = entry.size as usize - len; - (start, len) - } - }; - let mut file = File::open(&entry.name)?; - file.seek(io::SeekFrom::Start(start as u64))?; - let mut buf = vec![0u8; size]; - file.read_exact(&mut buf)?; - let response = ResponseBuilder::partial() - .headers(self.get_file_headers(entry)) - .header(HeaderName::ContentLength, &format!("{}", buf.len())) - .header( - HeaderName::ContentRange, - &format!("bytes {}-{}/{}", start, start + size, entry.size), - ) - .header(HeaderName::Vary, "*") - .body(Cursor::new(buf)) - .build(); - self.send_response(response, request)?; - } else { - warn!( - "RequestHandler::send_partial unreachable block: empty ranges" - ); - self.send_error( - StatusCode::RangeNotSatisifiable, - Some("invalid Range header"), - )?; - }; + let val = request + .headers + .get(&HeaderName::Range) + .ok_or_else(|| SevaError::MissingHeaderValue(HeaderName::Range))?; + let ranges = HttpParser::parse_bytes_range(val, entry.size as usize) + .map_err(|_| ParsingError::InvalidRangeHeader(val.to_string()))?; + + // we only serve the first range + if let Some(range) = ranges.into_iter().next() { + let mut file = File::open(&entry.name)?; + file.seek(io::SeekFrom::Start(range.start as u64))?; + let mut buf = vec![0u8; range.size]; + file.read_exact(&mut buf)?; + let response = ResponseBuilder::partial() + .headers(self.get_file_headers(entry)) + .header(HeaderName::ContentLength, &format!("{}", buf.len())) + .header( + HeaderName::ContentRange, + &format!( + "bytes {}-{}/{}", + range.start, + range.start + range.size, + entry.size + ), + ) + .header(HeaderName::Vary, "*") + .body(Cursor::new(buf)) + .build(); + self.send_response(response, request)?; } else { - error!("RequestHandler::send_partial unreachable block: missing range header value"); + warn!("RequestHandler::send_partial unreachable block: empty ranges"); self.send_error( - StatusCode::InternalServerError, - Some("missing range header value"), + StatusCode::RangeNotSatisifiable, + Some("invalid Range header"), )?; - } + }; + Ok(()) } @@ -734,6 +729,40 @@ mod tests { Ok(()) } + #[test] + fn invalid_range_unit() -> Result<()> { + let port = start_server()?; + + let client = reqwest::blocking::Client::new(); + let response = client + .get(format!("http://127.0.0.1:{port}/Cargo.toml")) + .header("Range", "bits=0-500") + .send() + .map_err(|e| SevaError::TestClient(format!("{}", e)))?; + + assert!(response.status().is_client_error()); + assert_eq!(response.status().as_u16(), 416); + + Ok(()) + } + + #[test] + fn empty_bytes_range() -> Result<()> { + let port = start_server()?; + + let client = reqwest::blocking::Client::new(); + let response = client + .get(format!("http://127.0.0.1:{port}/Cargo.toml")) + .header("Range", "bytes= ") + .send() + .map_err(|e| SevaError::TestClient(format!("{}", e)))?; + + assert!(response.status().is_client_error()); + assert_eq!(response.status().as_u16(), 416); + + Ok(()) + } + #[test] fn mime_type_works() -> Result<()> { let port = start_server()?;