Skip to content

Commit

Permalink
wasm: Return informative errors from leb128 parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
robinlinden committed Oct 31, 2023
1 parent 7f08dae commit a7fd50d
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 32 deletions.
28 changes: 18 additions & 10 deletions wasm/leb128.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#ifndef WASM_LEB128_H_
#define WASM_LEB128_H_

#include <tl/expected.hpp>

#include <cassert>
#include <cmath>
#include <concepts>
Expand All @@ -14,6 +16,12 @@

namespace wasm {

enum class Leb128ParseError {
Invalid,
NonZeroExtraBits,
UnexpectedEof,
};

// https://webassembly.github.io/spec/core/binary/values.html#integers
template<typename T>
requires std::integral<T>
Expand All @@ -22,15 +30,15 @@ struct Leb128 {};
// https://en.wikipedia.org/wiki/LEB128#Decode_unsigned_integer
template<std::unsigned_integral T>
struct Leb128<T> {
static std::optional<T> decode_from(std::istream &&is) { return decode_from(is); }
static std::optional<T> decode_from(std::istream &is) {
static tl::expected<T, Leb128ParseError> decode_from(std::istream &&is) { return decode_from(is); }
static tl::expected<T, Leb128ParseError> decode_from(std::istream &is) {
T result{};
std::uint8_t shift{};
auto const max_bytes = static_cast<int>(std::ceil(sizeof(T) * 8 / 7.f));
for (int i = 0; i < max_bytes; ++i) {
std::uint8_t byte{};
if (!is.read(reinterpret_cast<char *>(&byte), sizeof(byte))) {
return std::nullopt;
return tl::unexpected{Leb128ParseError::UnexpectedEof};
}

if (i == max_bytes - 1) {
Expand All @@ -40,7 +48,7 @@ struct Leb128<T> {
auto extra_bits_mask = (0xff << remaining_value_bits) & 0b0111'1111;
auto extra_bits = byte & extra_bits_mask;
if (extra_bits != 0) {
return std::nullopt;
return tl::unexpected{Leb128ParseError::NonZeroExtraBits};
}
}

Expand All @@ -52,7 +60,7 @@ struct Leb128<T> {
shift += 7;
}

return std::nullopt;
return tl::unexpected{Leb128ParseError::Invalid};
}
};

Expand All @@ -63,15 +71,15 @@ struct Leb128<T> {
static constexpr std::uint8_t kNonContinuationBits = 0b0111'1111;
static constexpr std::uint8_t kSignBit = 0b0100'0000;

static std::optional<T> decode_from(std::istream &&is) { return decode_from(is); }
static std::optional<T> decode_from(std::istream &is) {
static tl::expected<T, Leb128ParseError> decode_from(std::istream &&is) { return decode_from(is); }
static tl::expected<T, Leb128ParseError> decode_from(std::istream &is) {
T result{};
std::uint8_t shift{};
std::uint8_t byte{};
auto const max_bytes = static_cast<int>(std::ceil(sizeof(T) * 8 / 7.f));
for (int i = 0; i < max_bytes; ++i) {
if (!is.read(reinterpret_cast<char *>(&byte), sizeof(byte))) {
return std::nullopt;
return tl::unexpected{Leb128ParseError::UnexpectedEof};
}

if (i == max_bytes - 1) {
Expand All @@ -81,7 +89,7 @@ struct Leb128<T> {
auto extra_bits_mask = (0xff << remaining_value_bits) & kNonContinuationBits;
auto extra_bits = byte & extra_bits_mask;
if (extra_bits != 0 && extra_bits != extra_bits_mask) {
return std::nullopt;
return tl::unexpected{Leb128ParseError::NonZeroExtraBits};
}
}

Expand All @@ -93,7 +101,7 @@ struct Leb128<T> {
}

if (byte & kContinuationBit) {
return std::nullopt;
return tl::unexpected{Leb128ParseError::Invalid};
}

if ((shift < sizeof(T) * 8) && (byte & kSignBit)) {
Expand Down
52 changes: 31 additions & 21 deletions wasm/leb128_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using namespace std::literals;
using etest::expect_eq;
using wasm::Leb128;
using wasm::Leb128ParseError;

namespace {
template<typename T>
Expand All @@ -25,8 +26,12 @@ void expect_decoded(std::string bytes, T expected, etest::source_location loc =
};

template<typename T>
void expect_decode_failure(std::string bytes, etest::source_location loc = etest::source_location::current()) {
expect_eq(Leb128<T>::decode_from(std::stringstream{std::move(bytes)}), std::nullopt, std::nullopt, std::move(loc));
void expect_decode_failure(
std::string bytes, Leb128ParseError error, etest::source_location loc = etest::source_location::current()) {
expect_eq(Leb128<T>::decode_from(std::stringstream{std::move(bytes)}),
tl::unexpected{error},
std::nullopt,
std::move(loc));
};
} // namespace

Expand All @@ -38,9 +43,9 @@ int main() {
expect_decoded<std::uint32_t>("\x80\x7f", 16256);

// Missing termination.
expect_decode_failure<std::uint32_t>("\x80");
expect_decode_failure<std::uint32_t>("\x80", Leb128ParseError::UnexpectedEof);
// Too many bytes with no termination.
expect_decode_failure<std::uint32_t>("\x80\x80\x80\x80\x80\x80");
expect_decode_failure<std::uint32_t>("\x80\x80\x80\x80\x80\x80", Leb128ParseError::Invalid);

// https://github.com/llvm/llvm-project/blob/34aff47521c3e0cbac58b0d5793197f76a304295/llvm/unittests/Support/LEB128Test.cpp#L119-L142
expect_decoded<std::uint32_t>("\0"s, 0);
Expand All @@ -67,12 +72,13 @@ int main() {

// https://github.com/llvm/llvm-project/blob/34aff47521c3e0cbac58b0d5793197f76a304295/llvm/unittests/Support/LEB128Test.cpp#L160-L166
// Buffer overflow.
expect_decode_failure<std::uint64_t>("");
expect_decode_failure<std::uint64_t>("\x80");
expect_decode_failure<std::uint64_t>("", Leb128ParseError::UnexpectedEof);
expect_decode_failure<std::uint64_t>("\x80", Leb128ParseError::UnexpectedEof);

// Does not fit in 64 bits.
expect_decode_failure<std::uint64_t>("\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02");
expect_decode_failure<std::uint64_t>("\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02");
expect_decode_failure<std::uint64_t>(
"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02", Leb128ParseError::NonZeroExtraBits);
expect_decode_failure<std::uint64_t>("\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02", Leb128ParseError::Invalid);
});

etest::test("trailing zeros", [] {
Expand Down Expand Up @@ -102,11 +108,11 @@ int main() {
// 1 for negative ones.

// For example, 0x83 0x10 is malformed as a u8 encoding.
expect_decode_failure<std::uint8_t>("\x83\x10");
expect_decode_failure<std::uint8_t>("\x83\x10", Leb128ParseError::NonZeroExtraBits);

// Similarly, both 0x83 0x3E and 0xFF 0x7B are malformed as s8 encodings
expect_decode_failure<std::int8_t>("\x83\x3e");
expect_decode_failure<std::int8_t>("\xff\x7b");
expect_decode_failure<std::int8_t>("\x83\x3e", Leb128ParseError::NonZeroExtraBits);
expect_decode_failure<std::int8_t>("\xff\x7b", Leb128ParseError::NonZeroExtraBits);
});

etest::test("decode signed", [&] {
Expand Down Expand Up @@ -142,16 +148,20 @@ int main() {
expect_decoded<std::int64_t>("\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"s, kInt64Max);

// https://github.com/llvm/llvm-project/blob/34aff47521c3e0cbac58b0d5793197f76a304295/llvm/unittests/Support/LEB128Test.cpp#L229-L240
expect_decode_failure<std::int8_t>("");
expect_decode_failure<std::int8_t>("\x80");

expect_decode_failure<std::int64_t>("\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01");
expect_decode_failure<std::int64_t>("\x80\x80\x80\x80\x80\x80\x80\x80\x80\x7e");
expect_decode_failure<std::int64_t>("\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02");
expect_decode_failure<std::int64_t>("\xff\xff\xff\xff\xff\xff\xff\xff\xff\x7e");
expect_decode_failure<std::int64_t>("\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01");
expect_decode_failure<std::int64_t>("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x7e");
expect_decode_failure<std::int64_t>("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"s);
expect_decode_failure<std::int8_t>("", Leb128ParseError::UnexpectedEof);
expect_decode_failure<std::int8_t>("\x80", Leb128ParseError::UnexpectedEof);

expect_decode_failure<std::int64_t>(
"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01", Leb128ParseError::NonZeroExtraBits);
expect_decode_failure<std::int64_t>(
"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x7e", Leb128ParseError::NonZeroExtraBits);
expect_decode_failure<std::int64_t>("\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02", Leb128ParseError::Invalid);
expect_decode_failure<std::int64_t>(
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x7e", Leb128ParseError::NonZeroExtraBits);
expect_decode_failure<std::int64_t>(
"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01", Leb128ParseError::NonZeroExtraBits);
expect_decode_failure<std::int64_t>("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x7e", Leb128ParseError::Invalid);
expect_decode_failure<std::int64_t>("\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00"s, Leb128ParseError::Invalid);
});
// NOLINTEND(modernize-raw-string-literal)

Expand Down
3 changes: 2 additions & 1 deletion wasm/wasm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ std::optional<T> parse(std::istream &) = delete;

template<>
std::optional<std::uint32_t> parse(std::istream &is) {
return Leb128<std::uint32_t>::decode_from(is);
auto v = Leb128<std::uint32_t>::decode_from(is);
return v ? std::optional{*v} : std::nullopt;
}

template<>
Expand Down

0 comments on commit a7fd50d

Please sign in to comment.