Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

node: fix buffer includes+indexof #16642

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
231 changes: 144 additions & 87 deletions src/bun.js/bindings/JSBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,37 @@ extern "C" void* zig_memmem(const void* haystack, size_t haystack_len, const voi
#define MEMMEM_IMPL memmem
#endif

static ssize_t indexOfOffset(size_t length, ssize_t offset_i64, ssize_t needle_length, bool is_forward)
{
ssize_t length_i64 = static_cast<ssize_t>(length);
if (offset_i64 < 0) {
if (offset_i64 + length_i64 >= 0) {
// Negative offsets count backwards from the end of the buffer.
return length_i64 + offset_i64;
} else if (is_forward || needle_length == 0) {
// indexOf from before the start of the buffer: search the whole buffer.
return 0;
} else {
// lastIndexOf from before the start of the buffer: no match.
return -1;
}
} else {
if (offset_i64 + needle_length <= length_i64) {
// Valid positive offset.
return offset_i64;
} else if (needle_length == 0) {
// Out of buffer bounds, but empty needle: point to end of buffer.
return length_i64;
} else if (is_forward) {
// indexOf from past the end of the buffer: no match.
return -1;
} else {
// lastIndexOf from past the end of the buffer: search the whole buffer.
return length_i64 - 1;
}
}
}

static int64_t indexOf(const uint8_t* thisPtr, int64_t thisLength, const uint8_t* valuePtr, int64_t valueLength, int64_t byteOffset)
{
if (thisLength < valueLength + byteOffset)
Expand All @@ -1335,6 +1366,28 @@ static int64_t indexOf(const uint8_t* thisPtr, int64_t thisLength, const uint8_t
return -1;
}

static int64_t indexOf16(const uint8_t* thisPtr, int64_t thisLength, const uint8_t* valuePtr, int64_t valueLength, int64_t byteOffset)
{
size_t finalresult = 0;
nektro marked this conversation as resolved.
Show resolved Hide resolved
if (thisLength == 1) return -1;
thisLength = thisLength / 2 * 2;
if (valueLength == 1) return -1;
valueLength = valueLength / 2 * 2;
byteOffset = byteOffset / 2 * 2;
while (true) {
auto res = indexOf(thisPtr, thisLength, valuePtr, valueLength, byteOffset);
if (res == -1) return -1;
if (res % 2 == 1) {
thisPtr += res + 1;
thisLength -= res + 1;
finalresult += res + 1;
continue;
}
finalresult += res;
return finalresult;
}
}

static int64_t lastIndexOf(const uint8_t* thisPtr, int64_t thisLength, const uint8_t* valuePtr, int64_t valueLength, int64_t byteOffset)
{
auto start = thisPtr;
Expand All @@ -1346,108 +1399,112 @@ static int64_t lastIndexOf(const uint8_t* thisPtr, int64_t thisLength, const uin
return -1;
}

static int64_t indexOf(JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame, typename IDLOperation<JSArrayBufferView>::ClassParameter castedThis, bool last)
static int64_t indexOfNumber(JSC::JSGlobalObject* lexicalGlobalObject, bool last, const uint8_t* typedVector, size_t byteLength, double byteOffsetD, uint8_t byteValue)
{
auto& vm = JSC::getVM(lexicalGlobalObject);
auto scope = DECLARE_THROW_SCOPE(vm);
if (callFrame->argumentCount() < 1) {
throwVMError(lexicalGlobalObject, scope, createNotEnoughArgumentsError(lexicalGlobalObject));
return -1;
ssize_t byteOffset = indexOfOffset(byteLength, byteOffsetD, 1, !last);
if (byteOffset == -1) return -1;
if (last) {
for (int64_t i = byteOffset; i >= 0; --i) {
nektro marked this conversation as resolved.
Show resolved Hide resolved
if (byteValue == typedVector[i]) return i;
}
} else {
const void* offset = memchr(reinterpret_cast<const void*>(typedVector + byteOffset), byteValue, byteLength - byteOffset);
if (offset != NULL) return static_cast<const uint8_t*>(offset) - typedVector;
}
return -1;
}

auto value = callFrame->uncheckedArgument(0);
WebCore::BufferEncodingType encoding = WebCore::BufferEncodingType::utf8;

int64_t length = static_cast<int64_t>(castedThis->byteLength());
const uint8_t* typedVector = castedThis->typedVector();
static int64_t indexOfString(JSC::JSGlobalObject* lexicalGlobalObject, bool last, const uint8_t* typedVector, size_t byteLength, double byteOffsetD, JSString* str, BufferEncodingType encoding)
{
ssize_t byteOffset = indexOfOffset(byteLength, byteOffsetD, str->length(), !last);
if (byteOffset == -1) return -1;
if (str->length() == 0) return byteOffset;
JSC::EncodedJSValue encodedBuffer = constructFromEncoding(lexicalGlobalObject, str, encoding);
auto* arrayValue = JSC::jsCast<JSC::JSUint8Array*>(JSC::JSValue::decode(encodedBuffer));
int64_t lengthValue = static_cast<int64_t>(arrayValue->byteLength());
const uint8_t* typedVectorValue = arrayValue->typedVector();
if (last) {
return lastIndexOf(typedVector, byteLength, typedVectorValue, lengthValue, byteOffset);
} else {
if (encoding == BufferEncodingType::ucs2) return indexOf16(typedVector, byteLength, typedVectorValue, lengthValue, byteOffset);
return indexOf(typedVector, byteLength, typedVectorValue, lengthValue, byteOffset);
}
}

int64_t byteOffset = last ? length - 1 : 0;
static int64_t indexOfBuffer(JSC::JSGlobalObject* lexicalGlobalObject, bool last, const uint8_t* typedVector, size_t byteLength, double byteOffsetD, JSC::JSGenericTypedArrayView<JSC::Uint8Adaptor>* array, BufferEncodingType encoding)
{
size_t lengthValue = array->byteLength();
ssize_t byteOffset = indexOfOffset(byteLength, byteOffsetD, lengthValue, !last);
if (byteOffset == -1) return -1;
if (lengthValue == 0) return byteOffset;
const uint8_t* typedVectorValue = array->typedVector();
if (last) {
return lastIndexOf(typedVector, byteLength, typedVectorValue, lengthValue, byteOffset);
} else {
if (encoding == BufferEncodingType::ucs2) return indexOf16(typedVector, byteLength, typedVectorValue, lengthValue, byteOffset);
return indexOf(typedVector, byteLength, typedVectorValue, lengthValue, byteOffset);
}
}

if (callFrame->argumentCount() > 1) {
EnsureStillAliveScope arg1 = callFrame->uncheckedArgument(1);
if (arg1.value().isString()) {
encoding = parseEncoding(lexicalGlobalObject, scope, arg1.value());
RETURN_IF_EXCEPTION(scope, -1);
} else {
auto byteOffset_ = arg1.value().toNumber(lexicalGlobalObject);
RETURN_IF_EXCEPTION(scope, -1);

if (std::isnan(byteOffset_) || std::isinf(byteOffset_)) {
byteOffset = last ? length - 1 : 0;
} else if (byteOffset_ < 0) {
byteOffset = length + static_cast<int64_t>(byteOffset_);
} else {
byteOffset = static_cast<int64_t>(byteOffset_);
}
static int64_t indexOf(JSC::JSGlobalObject* lexicalGlobalObject, JSC::CallFrame* callFrame, typename IDLOperation<JSArrayBufferView>::ClassParameter buffer, bool last)
{
auto& vm = lexicalGlobalObject->vm();
auto scope = DECLARE_THROW_SCOPE(vm);
bool dir = !last;
const uint8_t* typedVector = buffer->typedVector();
size_t byteLength = buffer->byteLength();
std::optional<BufferEncodingType> encoding = std::nullopt;
double byteOffsetD = 0;

if (byteLength == 0) return -1;

auto valueValue = callFrame->argument(0);
auto byteOffsetValue = callFrame->argument(1);
auto encodingValue = callFrame->argument(2);

if (byteOffsetValue.isString()) {
encodingValue = byteOffsetValue;
byteOffsetValue = jsUndefined();
byteOffsetD = 0;
} else {
byteOffsetD = byteOffsetValue.toNumber(lexicalGlobalObject);
RETURN_IF_EXCEPTION(scope, -1);
if (byteOffsetD > 0x7fffffffp0f) byteOffsetD = 0x7fffffffp0f;
if (byteOffsetD < -0x80000000p0f) byteOffsetD = -0x80000000p0f;
}

if (last) {
if (byteOffset < 0) {
return -1;
} else if (byteOffset > length - 1) {
byteOffset = length - 1;
}
} else {
if (byteOffset <= 0) {
byteOffset = 0;
} else if (byteOffset > length - 1) {
return -1;
}
}
if (std::isnan(byteOffsetD)) byteOffsetD = dir ? 0 : byteLength;

if (callFrame->argumentCount() > 2) {
EnsureStillAliveScope encodingValue = callFrame->uncheckedArgument(2);
if (!encodingValue.value().isUndefined()) {
encoding = parseEncoding(lexicalGlobalObject, scope, encodingValue.value());
RETURN_IF_EXCEPTION(scope, -1);
}
}
}
if (valueValue.isNumber()) {
uint8_t byteValue = (valueValue.toInt32(lexicalGlobalObject)) % 256;
nektro marked this conversation as resolved.
Show resolved Hide resolved
RETURN_IF_EXCEPTION(scope, -1);
return indexOfNumber(lexicalGlobalObject, last, typedVector, byteLength, byteOffsetD, byteValue);
}

if (value.isString()) {
auto* str = value.toStringOrNull(lexicalGlobalObject);
RETURN_IF_EXCEPTION(scope, -1);
WTF::String encodingString;
if (!encodingValue.isUndefined()) {
encodingString = encodingValue.toWTFString(lexicalGlobalObject);
RETURN_IF_EXCEPTION(scope, {});
encoding = parseEnumeration2(*lexicalGlobalObject, encodingString);
} else {
encoding = BufferEncodingType::utf8;
}

JSC::EncodedJSValue encodedBuffer = constructFromEncoding(lexicalGlobalObject, str, encoding);
auto* arrayValue = JSC::jsDynamicCast<JSC::JSUint8Array*>(JSC::JSValue::decode(encodedBuffer));
int64_t lengthValue = static_cast<int64_t>(arrayValue->byteLength());
const uint8_t* typedVectorValue = arrayValue->typedVector();
if (last) {
return lastIndexOf(typedVector, length, typedVectorValue, lengthValue, byteOffset);
} else {
return indexOf(typedVector, length, typedVectorValue, lengthValue, byteOffset);
if (valueValue.isString()) {
if (!encoding.has_value()) {
return Bun::ERR::UNKNOWN_ENCODING(scope, lexicalGlobalObject, encodingString);
}
} else if (value.isNumber()) {
uint8_t byteValue = static_cast<uint8_t>((value.toInt32(lexicalGlobalObject)) % 256);
auto* str = valueValue.toStringOrNull(lexicalGlobalObject);
RETURN_IF_EXCEPTION(scope, -1);
return indexOfString(lexicalGlobalObject, last, typedVector, byteLength, byteOffsetD, str, encoding.value());
}

if (last) {
for (int64_t i = byteOffset; i >= 0; --i) {
if (byteValue == typedVector[i]) {
return i;
}
}
} else {
const void* offset = memchr(reinterpret_cast<const void*>(typedVector + byteOffset), byteValue, length - byteOffset);
if (offset != NULL) {
return static_cast<int64_t>(static_cast<const uint8_t*>(offset) - typedVector);
}
}

return -1;
} else if (auto* arrayValue = JSC::jsDynamicCast<JSC::JSUint8Array*>(value)) {
size_t lengthValue = arrayValue->byteLength();
const uint8_t* typedVectorValue = arrayValue->typedVector();
if (last) {
return lastIndexOf(typedVector, length, typedVectorValue, lengthValue, byteOffset);
} else {
return indexOf(typedVector, length, typedVectorValue, lengthValue, byteOffset);
}
} else {
throwTypeError(lexicalGlobalObject, scope, "Invalid value type"_s);
return -1;
if (auto* array = JSC::jsDynamicCast<JSC::JSUint8Array*>(valueValue)) {
if (!encoding.has_value()) encoding = BufferEncodingType::utf8;
return indexOfBuffer(lexicalGlobalObject, last, typedVector, byteLength, byteOffsetD, array, encoding.value());
}

Bun::ERR::INVALID_ARG_TYPE(scope, lexicalGlobalObject, "value"_s, "number, string, Buffer, or Uint8Array"_s, valueValue);
return -1;
}

Expand Down
Loading