Skip to content

Commit

Permalink
Merge pull request RustPython#4460 from evanrittenhouse/strict_mode_new
Browse files Browse the repository at this point in the history
Implement `strict_mode` keyword for `binascii.a2b_base64()`
  • Loading branch information
youknowone authored Mar 4, 2023
2 parents 1f92212 + ff973ca commit 4e19be7
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 37 deletions.
6 changes: 0 additions & 6 deletions Lib/test/test_binascii.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ def test_base64valid(self):
res += b
self.assertEqual(res, self.rawdata)

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_base64invalid(self):
# Test base64 with random invalid characters sprinkled throughout
# (This requires a new version of binascii.)
Expand Down Expand Up @@ -114,8 +112,6 @@ def addnoise(line):
# empty strings. TBD: shouldn't it raise an exception instead ?
self.assertEqual(binascii.a2b_base64(self.type2test(fillers)), b'')

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_base64_strict_mode(self):
# Test base64 with strict mode on
def _assertRegexTemplate(assert_regex: str, data: bytes, non_strict_mode_expected_result: bytes):
Expand Down Expand Up @@ -159,8 +155,6 @@ def assertDiscontinuousPadding(data, non_strict_mode_expected_result: bytes):
assertDiscontinuousPadding(b'ab=c=', b'i\xb7')
assertDiscontinuousPadding(b'ab=ab==', b'i\xb6\x9b')

# TODO: RUSTPYTHON
@unittest.expectedFailure
def test_base64errors(self):
# Test base64 with invalid padding
def assertIncorrectPadding(data):
Expand Down
160 changes: 129 additions & 31 deletions stdlib/src/binascii.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
pub(super) use decl::crc32;
pub(crate) use decl::make_module;
use rustpython_vm::{builtins::PyBaseExceptionRef, convert::ToPyException, VirtualMachine};

pub(super) use decl::crc32;

pub fn decode<T: AsRef<[u8]>>(input: T) -> Result<Vec<u8>, base64::DecodeError> {
base64::decode_config(input, base64::STANDARD.decode_allow_trailing_bits(true))
}
const PAD: u8 = 61u8;
const MAXLINESIZE: usize = 76; // Excluding the CRLF

#[pymodule(name = "binascii")]
mod decl {
use super::decode;
use super::{MAXLINESIZE, PAD};
use crate::vm::{
builtins::{PyBaseExceptionRef, PyIntRef, PyTypeRef},
builtins::{PyIntRef, PyTypeRef},
convert::ToPyException,
function::{ArgAsciiBuffer, ArgBytesLike, OptionalArg},
PyResult, VirtualMachine,
};
use itertools::Itertools;

const MAXLINESIZE: usize = 76;

#[pyattr(name = "Error", once)]
fn error_type(vm: &VirtualMachine) -> PyTypeRef {
pub(super) fn error_type(vm: &VirtualMachine) -> PyTypeRef {
vm.ctx.new_exception_type(
"binascii",
"Error",
Expand Down Expand Up @@ -67,15 +65,18 @@ mod decl {
fn unhexlify(data: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
data.with_ref(|hex_bytes| {
if hex_bytes.len() % 2 != 0 {
return Err(new_binascii_error("Odd-length string".to_owned(), vm));
return Err(super::new_binascii_error(
"Odd-length string".to_owned(),
vm,
));
}

let mut unhex = Vec::<u8>::with_capacity(hex_bytes.len() / 2);
for (n1, n2) in hex_bytes.iter().tuples() {
if let (Some(n1), Some(n2)) = (unhex_nibble(*n1), unhex_nibble(*n2)) {
unhex.push(n1 << 4 | n2);
} else {
return Err(new_binascii_error(
return Err(super::new_binascii_error(
"Non-hexadecimal digit found".to_owned(),
vm,
));
Expand Down Expand Up @@ -144,13 +145,20 @@ mod decl {
newline: bool,
}

fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(error_type(vm), msg)
#[derive(FromArgs)]
struct A2bBase64Args {
#[pyarg(any)]
s: ArgAsciiBuffer,
#[pyarg(named, default = "false")]
strict_mode: bool,
}

#[pyfunction]
fn a2b_base64(s: ArgAsciiBuffer, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
fn a2b_base64(args: A2bBase64Args, vm: &VirtualMachine) -> PyResult<Vec<u8>> {
#[rustfmt::skip]
// Converts between ASCII and base-64 characters. The index of a given number yields the
// number in ASCII while the value of said index yields the number in base-64. For example
// "=" is 61 in ASCII but 0 (since it's the pad character) in base-64, so BASE64_TABLE[61] == 0
const BASE64_TABLE: [i8; 256] = [
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
Expand All @@ -171,25 +179,92 @@ mod decl {
-1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1, -1,-1,-1,-1,
];

let A2bBase64Args { s, strict_mode } = args;
s.with_ref(|b| {
let decoded = if b.len() % 4 == 0 {
decode(b)
} else {
Err(base64::DecodeError::InvalidLength)
};
decoded.or_else(|_| {
let buf: Vec<_> = b
.iter()
.copied()
.filter(|&c| BASE64_TABLE[c as usize] != -1)
.collect();
if buf.len() % 4 != 0 {
return Err(base64::DecodeError::InvalidLength);
if b.is_empty() {
return Ok(vec![]);
}

if strict_mode && b[0] == PAD {
return Err(base64::DecodeError::InvalidByte(0, 61));
}

let mut decoded: Vec<u8> = vec![];

let mut quad_pos = 0; // position in the nibble
let mut pads = 0;
let mut left_char: u8 = 0;
let mut padding_started = false;
for (i, &el) in b.iter().enumerate() {
if el == PAD {
padding_started = true;

pads += 1;
if quad_pos >= 2 && quad_pos + pads >= 4 {
if strict_mode && i + 1 < b.len() {
// Represents excess data after padding error
return Err(base64::DecodeError::InvalidLastSymbol(i, PAD));
}

return Ok(decoded);
}

continue;
}
decode(&buf)
})

let binary_char = BASE64_TABLE[el as usize];
if binary_char >= 64 || binary_char == -1 {
if strict_mode {
// Represents non-base64 data error
return Err(base64::DecodeError::InvalidByte(i, el));
}
continue;
}

if strict_mode && padding_started {
// Represents discontinuous padding error
return Err(base64::DecodeError::InvalidByte(i, PAD));
}
pads = 0;

// Decode individual ASCII character
match quad_pos {
0 => {
quad_pos = 1;
left_char = binary_char as u8;
}
1 => {
quad_pos = 2;
decoded.push((left_char << 2) | (binary_char >> 4) as u8);
left_char = (binary_char & 0x0f) as u8;
}
2 => {
quad_pos = 3;
decoded.push((left_char << 4) | (binary_char >> 2) as u8);
left_char = (binary_char & 0x03) as u8;
}
3 => {
quad_pos = 0;
decoded.push((left_char << 6) | binary_char as u8);
left_char = 0;
}
_ => unsafe {
// quad_pos is only assigned in this match statement to constants
std::hint::unreachable_unchecked()
},
}
}

match quad_pos {
0 => Ok(decoded),
1 => Err(base64::DecodeError::InvalidLastSymbol(
decoded.len() / 3 * 4 + 1,
0,
)),
_ => Err(base64::DecodeError::InvalidLength),
}
})
.map_err(|err| new_binascii_error(format!("error decoding base64: {err}"), vm))
.map_err(|err| super::Base64DecodeError(err).to_pyexception(vm))
}

#[pyfunction]
Expand Down Expand Up @@ -654,3 +729,26 @@ mod decl {
})
}
}

struct Base64DecodeError(base64::DecodeError);

fn new_binascii_error(msg: String, vm: &VirtualMachine) -> PyBaseExceptionRef {
vm.new_exception_msg(decl::error_type(vm), msg)
}

impl ToPyException for Base64DecodeError {
fn to_pyexception(&self, vm: &VirtualMachine) -> PyBaseExceptionRef {
use base64::DecodeError::*;
let message = match self.0 {
InvalidByte(0, PAD) => "Leading padding not allowed".to_owned(),
InvalidByte(_, PAD) => "Discontinuous padding not allowed".to_owned(),
InvalidByte(_, _) => "Only base64 data is allowed".to_owned(),
InvalidLastSymbol(_, PAD) => "Excess data after padding".to_owned(),
InvalidLastSymbol(length, _) => {
format!("Invalid base64-encoded string: number of data characters {} cannot be 1 more than a multiple of 4", length)
}
InvalidLength => "Incorrect padding".to_owned(),
};
new_binascii_error(format!("error decoding base64: {message}"), vm)
}
}

0 comments on commit 4e19be7

Please sign in to comment.