diff --git a/src/js/bun/sql.ts b/src/js/bun/sql.ts index abe2a973cc6793..a619c9c12129d1 100644 --- a/src/js/bun/sql.ts +++ b/src/js/bun/sql.ts @@ -1,3 +1,5 @@ +const { hideFromStack } = require("internal/shared"); + const enum QueryStatus { active = 1 << 1, cancelled = 1 << 2, @@ -15,6 +17,11 @@ const enum SSLMode { verify_full = 4, } +function connectionClosedError() { + return $ERR_POSTGRES_CONNECTION_CLOSED("Connection closed"); +} +hideFromStack(connectionClosedError); + class SQLResultArray extends PublicArray { static [Symbol.toStringTag] = "SQLResults"; @@ -33,6 +40,7 @@ const _run = Symbol("run"); const _queryStatus = Symbol("status"); const _handler = Symbol("handler"); const PublicPromise = Promise; +type TransactionCallback = (sql: (strings: string, ...values: any[]) => Query) => Promise; const { createConnection: _createConnection, @@ -654,9 +662,103 @@ function SQL(o) { return pendingSQL(strings, values); } + sql.begin = async (fn: TransactionCallback) => { + /* + BEGIN; -- works on POSTGRES, MySQL, and SQLite (need to change to BEGIN TRANSACTION on MSSQL) + + -- Create a SAVEPOINT + SAVEPOINT my_savepoint; -- works on POSTGRES, MySQL, and SQLite (need to change to SAVE TRANSACTION on MSSQL) + + -- QUERY + + -- Roll back to SAVEPOINT if needed + ROLLBACK TO SAVEPOINT my_savepoint; -- works on POSTGRES, MySQL, and SQLite (need to change to ROLLBACK TRANSACTION on MSSQL) + + -- Release the SAVEPOINT + RELEASE SAVEPOINT my_savepoint; -- works on POSTGRES, MySQL, and SQLite (MSSQL dont have RELEASE SAVEPOINT you just need to transaction again) + + -- Commit the transaction + COMMIT; -- works on POSTGRES, MySQL, and SQLite (need to change to COMMIT TRANSACTION on MSSQL) + -- or rollback everything + ROLLBACK; -- works on POSTGRES, MySQL, and SQLite (need to change to ROLLBACK TRANSACTION on MSSQL) + + */ + + // this is a big TODO we need to make sure that each created query actually uses the same connection or fails + let current_connection; + let savepoints = 0; + try { + if (closed) { + throw connectionClosedError(); + } + if (!$isCallable(fn)) { + throw $ERR_INVALID_ARG_VALUE("fn", fn, "must be a function"); + } + //@ts-ignore + await sql("BEGIN"); + // keep track of the connection that is being used + current_connection = connection; + + // we need a function able to check for the current connection + const sql_with_savepoint = function (strings, ...values) { + return sql(strings, ...values); + }; + // dirt copy of the sql object + for (const key in sql) { + sql_with_savepoint[key] = sql[key]; + } + // this version accepts savepoints + sql_with_savepoint.savepoint = async (fn: TransactionCallback, name?: string) => { + let callback = fn; + + if (closed || current_connection !== connection) { + throw connectionClosedError(); + } + if ($isCallable(name)) { + callback = name as unknown as TransactionCallback; + name = ""; + } + if (!$isCallable(callback)) { + throw $ERR_INVALID_ARG_VALUE("fn", callback, "must be a function"); + } + // matchs the format of the savepoint name in postgres package + const save_point_name = `s${savepoints++}${name ? `_${name}` : ""}`; + + try { + await sql_with_savepoint`SAVEPOINT ${save_point_name}`; + const result = await callback(sql_with_savepoint); + if (!closed && current_connection === connection) { + await sql_with_savepoint(`RELEASE SAVEPOINT ${save_point_name}`); + } else { + throw connectionClosedError(); + } + return result; + } catch (err) { + if (!closed && current_connection === connection) { + await sql_with_savepoint(`ROLLBACK TO SAVEPOINT ${save_point_name}`); + } + throw err; + } + }; + + const transaction_result = await fn(sql_with_savepoint); + if (!closed && current_connection === connection) { + await sql("COMMIT"); + } else { + throw connectionClosedError(); + } + return transaction_result; + } catch (err) { + if (current_connection && !closed && current_connection === connection) { + await sql("ROLLBACK"); + } + throw err; + } + }; + sql.connect = () => { if (closed) { - return Promise.reject(new Error("Connection closed")); + return Promise.reject(connectionClosedError()); } if (connected) { @@ -697,7 +799,7 @@ function SQL(o) { sql.then = () => { if (closed) { - return Promise.reject(new Error("Connection closed")); + return Promise.reject(connectionClosedError()); } if (connected) { diff --git a/src/sql/postgres.zig b/src/sql/postgres.zig index c0f2bbef847e23..b93000fefcb912 100644 --- a/src/sql/postgres.zig +++ b/src/sql/postgres.zig @@ -484,9 +484,14 @@ pub const PostgresSQLQuery = struct { pub fn call(globalThis: *JSC.JSGlobalObject, callframe: *JSC.CallFrame) bun.JSError!JSC.JSValue { const arguments = callframe.arguments_old(4).slice(); - const query = arguments[0]; - const values = arguments[1]; - const columns = arguments[3]; + var args = JSC.Node.ArgumentsSlice.init(globalThis.bunVM(), arguments); + defer args.deinit(); + const query = args.nextEat() orelse { + return globalThis.throw("query must be a string", .{}); + }; + const values = args.nextEat() orelse { + return globalThis.throw("values must be an array", .{}); + }; if (!query.isString()) { return globalThis.throw("query must be a string", .{}); @@ -496,7 +501,9 @@ pub const PostgresSQLQuery = struct { return globalThis.throw("values must be an array", .{}); } - const pending_value = arguments[2]; + const pending_value = args.nextEat() orelse .undefined; + const columns = args.nextEat() orelse .undefined; + if (!pending_value.jsType().isArrayLike()) { return globalThis.throwInvalidArgumentType("query", "pendingValue", "Array"); }