From 8b82edd5d918068f204e3ac4206d5d15f6740395 Mon Sep 17 00:00:00 2001 From: Graham Burgsma Date: Wed, 24 Jun 2020 16:25:44 -0400 Subject: [PATCH] Add support for RETURNING clause (#110) * Support RETURNING statement * Simplified SQLReturning * Updated SQLReturning serialization * Added supportsReturning to SQLDialect * Added SQLReturningBuilder to support returning statement with query builders. * Updated SQLReturningBuilder doc comments --- .../SQLKit/Builders/SQLDeleteBuilder.swift | 7 ++- .../SQLKit/Builders/SQLInsertBuilder.swift | 7 ++- .../SQLKit/Builders/SQLReturningBuilder.swift | 47 +++++++++++++++++++ .../SQLKit/Builders/SQLUpdateBuilder.swift | 7 ++- Sources/SQLKit/Query/SQLDelete.swift | 18 +++++-- Sources/SQLKit/Query/SQLInsert.swift | 19 +++++--- Sources/SQLKit/Query/SQLReturning.swift | 29 ++++++++++++ Sources/SQLKit/Query/SQLUpdate.swift | 22 ++++++--- Sources/SQLKit/SQLDialect.swift | 5 ++ Tests/SQLKitTests/SQLKitTests.swift | 22 +++++++++ Tests/SQLKitTests/Utilities.swift | 2 + 11 files changed, 164 insertions(+), 21 deletions(-) create mode 100644 Sources/SQLKit/Builders/SQLReturningBuilder.swift create mode 100644 Sources/SQLKit/Query/SQLReturning.swift diff --git a/Sources/SQLKit/Builders/SQLDeleteBuilder.swift b/Sources/SQLKit/Builders/SQLDeleteBuilder.swift index 17125c2b..2a04f04a 100644 --- a/Sources/SQLKit/Builders/SQLDeleteBuilder.swift +++ b/Sources/SQLKit/Builders/SQLDeleteBuilder.swift @@ -4,7 +4,7 @@ /// .where(\.name != "Earth").run() /// /// See `SQLQueryBuilder` and `SQLPredicateBuilder` for more information. -public final class SQLDeleteBuilder: SQLQueryBuilder, SQLPredicateBuilder { +public final class SQLDeleteBuilder: SQLQueryBuilder, SQLPredicateBuilder, SQLReturningBuilder { /// `Delete` query being built. public var delete: SQLDelete @@ -18,6 +18,11 @@ public final class SQLDeleteBuilder: SQLQueryBuilder, SQLPredicateBuilder { get { return self.delete.predicate } set { self.delete.predicate = newValue } } + + public var returning: SQLReturning? { + get { return self.delete.returning } + set { self.delete.returning = newValue } + } /// Creates a new `SQLDeleteBuilder`. public init(_ delete: SQLDelete, on database: SQLDatabase) { diff --git a/Sources/SQLKit/Builders/SQLInsertBuilder.swift b/Sources/SQLKit/Builders/SQLInsertBuilder.swift index 9136b955..2e46b000 100644 --- a/Sources/SQLKit/Builders/SQLInsertBuilder.swift +++ b/Sources/SQLKit/Builders/SQLInsertBuilder.swift @@ -4,7 +4,7 @@ /// .value(earth).run() /// /// See `SQLQueryBuilder` for more information. -public final class SQLInsertBuilder: SQLQueryBuilder { +public final class SQLInsertBuilder: SQLQueryBuilder, SQLReturningBuilder { /// `Insert` query being built. public var insert: SQLInsert @@ -15,6 +15,11 @@ public final class SQLInsertBuilder: SQLQueryBuilder { public var query: SQLExpression { return self.insert } + + public var returning: SQLReturning? { + get { return self.insert.returning } + set { self.insert.returning = newValue } + } /// Creates a new `SQLInsertBuilder`. public init(_ insert: SQLInsert, on database: SQLDatabase) { diff --git a/Sources/SQLKit/Builders/SQLReturningBuilder.swift b/Sources/SQLKit/Builders/SQLReturningBuilder.swift new file mode 100644 index 00000000..8451979e --- /dev/null +++ b/Sources/SQLKit/Builders/SQLReturningBuilder.swift @@ -0,0 +1,47 @@ +public protocol SQLReturningBuilder: SQLQueryBuilder { + var returning: SQLReturning? { get set } +} + +extension SQLReturningBuilder { + /// Specify a list of columns to be part of the result set of the query. + /// Each provided name is a string assumed to be a valid SQL identifier and + /// is not qualified. + /// + /// - parameters: + /// - columns: The names of the columns to return. + /// - returns: Self for chaining. + public func returning(_ columns: String...) -> Self { + let sqlColumns = columns.map { (column) -> SQLColumn in + if column == "*" { + return SQLColumn(SQLLiteral.all) + } else { + return SQLColumn(column) + } + } + + self.returning = .init(sqlColumns) + return self + } + + /// Specify a list of columns to be returned as the result of the query. + /// Each input is an arbitrary expression. + /// + /// - parameters: + /// - columns: A list of expressions identifying the columns to return. + /// - returns: Self for chaining. + public func returning(_ columns: SQLExpression...) -> Self { + self.returning = .init(columns) + return self + } + + /// Specify a list of columns to be returned as the result of the query. + /// Each input is an arbitrary expression. + /// + /// - parameters: + /// - column: An array of expressions identifying the columns to return. + /// - returns: Self for chaining. + public func returning(_ columns: [SQLExpression]) -> Self { + self.returning = .init(columns) + return self + } +} diff --git a/Sources/SQLKit/Builders/SQLUpdateBuilder.swift b/Sources/SQLKit/Builders/SQLUpdateBuilder.swift index b152b0d9..f6005f7e 100644 --- a/Sources/SQLKit/Builders/SQLUpdateBuilder.swift +++ b/Sources/SQLKit/Builders/SQLUpdateBuilder.swift @@ -6,7 +6,7 @@ /// .run() /// /// See `SQLQueryBuilder` and `SQLPredicateBuilder` for more information. -public final class SQLUpdateBuilder: SQLQueryBuilder, SQLPredicateBuilder { +public final class SQLUpdateBuilder: SQLQueryBuilder, SQLPredicateBuilder, SQLReturningBuilder { /// `Update` query being built. public var update: SQLUpdate @@ -20,6 +20,11 @@ public final class SQLUpdateBuilder: SQLQueryBuilder, SQLPredicateBuilder { get { return self.update.predicate } set { self.update.predicate = newValue } } + + public var returning: SQLReturning? { + get { return self.update.returning } + set { self.update.returning = newValue } + } /// Creates a new `SQLDeleteBuilder`. public init(_ update: SQLUpdate, on database: SQLDatabase) { diff --git a/Sources/SQLKit/Query/SQLDelete.swift b/Sources/SQLKit/Query/SQLDelete.swift index ecd72468..3918af61 100644 --- a/Sources/SQLKit/Query/SQLDelete.swift +++ b/Sources/SQLKit/Query/SQLDelete.swift @@ -9,6 +9,9 @@ public struct SQLDelete: SQLExpression { /// then only those rows for which the WHERE clause boolean expression is true are deleted. Rows for which /// the expression is false or NULL are retained. public var predicate: SQLExpression? + + /// Optionally append a `RETURNING` clause that, where supported, returns the supplied supplied columns. + public var returning: SQLReturning? /// Creates a new `SQLDelete`. public init(table: SQLExpression) { @@ -16,11 +19,16 @@ public struct SQLDelete: SQLExpression { } public func serialize(to serializer: inout SQLSerializer) { - serializer.write("DELETE FROM ") - self.table.serialize(to: &serializer) - if let predicate = self.predicate { - serializer.write(" WHERE ") - predicate.serialize(to: &serializer) + serializer.statement { + $0.append("DELETE FROM") + $0.append(self.table) + if let predicate = self.predicate { + $0.append("WHERE") + $0.append(predicate) + } + if let returning = self.returning { + $0.append(returning) + } } } } diff --git a/Sources/SQLKit/Query/SQLInsert.swift b/Sources/SQLKit/Query/SQLInsert.swift index 7607929d..06ab25a1 100644 --- a/Sources/SQLKit/Query/SQLInsert.swift +++ b/Sources/SQLKit/Query/SQLInsert.swift @@ -12,6 +12,9 @@ public struct SQLInsert: SQLExpression { /// /// Use the `DEFAULT` literal to omit a value and that is specified as a column. public var values: [[SQLExpression]] + + /// Optionally append a `RETURNING` clause that, where supported, returns the supplied supplied columns. + public var returning: SQLReturning? /// Creates a new `SQLInsert`. public init(table: SQLExpression) { @@ -21,11 +24,15 @@ public struct SQLInsert: SQLExpression { } public func serialize(to serializer: inout SQLSerializer) { - serializer.write("INSERT INTO ") - self.table.serialize(to: &serializer) - serializer.write(" ") - SQLGroupExpression(self.columns).serialize(to: &serializer) - serializer.write(" VALUES ") - SQLList(self.values.map(SQLGroupExpression.init)).serialize(to: &serializer) + serializer.statement { + $0.append("INSERT INTO") + $0.append(self.table) + $0.append(SQLGroupExpression(self.columns)) + $0.append("VALUES") + $0.append(SQLList(self.values.map(SQLGroupExpression.init))) + if let returning = self.returning { + $0.append(returning) + } + } } } diff --git a/Sources/SQLKit/Query/SQLReturning.swift b/Sources/SQLKit/Query/SQLReturning.swift new file mode 100644 index 00000000..b43b2180 --- /dev/null +++ b/Sources/SQLKit/Query/SQLReturning.swift @@ -0,0 +1,29 @@ +/// `RETURNING ...` statement. +/// +public struct SQLReturning: SQLExpression { + public var columns: [SQLExpression] + + /// Creates a new `SQLReturning`. + public init(_ column: SQLColumn) { + self.columns = [column] + } + + /// Creates a new `SQLReturning`. + public init(_ columns: [SQLExpression]) { + self.columns = columns + } + + public func serialize(to serializer: inout SQLSerializer) { + guard serializer.dialect.supportsReturning else { + serializer.database.logger.warning("\(serializer.dialect.name) does not support 'RETURNING' clause, skipping.") + return + } + + guard !columns.isEmpty else { return } + + serializer.statement { + $0.append("RETURNING") + $0.append(SQLList(columns)) + } + } +} diff --git a/Sources/SQLKit/Query/SQLUpdate.swift b/Sources/SQLKit/Query/SQLUpdate.swift index 7767b510..b848d01d 100644 --- a/Sources/SQLKit/Query/SQLUpdate.swift +++ b/Sources/SQLKit/Query/SQLUpdate.swift @@ -10,6 +10,9 @@ public struct SQLUpdate: SQLExpression { /// Optional predicate to limit updated rows. public var predicate: SQLExpression? + + /// Optionally append a `RETURNING` clause that, where supported, returns the supplied supplied columns. + public var returning: SQLReturning? /// Creates a new `SQLUpdate`. public init(table: SQLExpression) { @@ -19,13 +22,18 @@ public struct SQLUpdate: SQLExpression { } public func serialize(to serializer: inout SQLSerializer) { - serializer.write("UPDATE ") - self.table.serialize(to: &serializer) - serializer.write(" SET ") - SQLList(self.values).serialize(to: &serializer) - if let predicate = self.predicate { - serializer.write(" WHERE ") - predicate.serialize(to: &serializer) + serializer.statement { + $0.append("UPDATE") + $0.append(self.table) + $0.append("SET") + $0.append(SQLList(self.values)) + if let predicate = self.predicate { + $0.append("WHERE") + $0.append(predicate) + } + if let returning = self.returning { + $0.append(returning) + } } } } diff --git a/Sources/SQLKit/SQLDialect.swift b/Sources/SQLKit/SQLDialect.swift index 32775871..ade6cb90 100644 --- a/Sources/SQLKit/SQLDialect.swift +++ b/Sources/SQLKit/SQLDialect.swift @@ -11,6 +11,7 @@ public protocol SQLDialect { var autoIncrementFunction: SQLExpression? { get } var enumSyntax: SQLEnumSyntax { get } var supportsDropBehavior: Bool { get } + var supportsReturning: Bool { get } var triggerSyntax: SQLTriggerSyntax { get } var alterTableSyntax: SQLAlterTableSyntax { get } func customDataType(for dataType: SQLDataType) -> SQLExpression? @@ -135,6 +136,10 @@ extension SQLDialect { return false } + public var supportsReturning: Bool { + return false + } + public var triggerSyntax: SQLTriggerSyntax { return SQLTriggerSyntax() } diff --git a/Tests/SQLKitTests/SQLKitTests.swift b/Tests/SQLKitTests/SQLKitTests.swift index ffe8b8fa..d4552d63 100644 --- a/Tests/SQLKitTests/SQLKitTests.swift +++ b/Tests/SQLKitTests/SQLKitTests.swift @@ -257,6 +257,28 @@ final class SQLKitTests: XCTestCase { XCTAssertEqual(db.results[0], "UPDATE `planets` SET `moons` = `moons` + 1 WHERE `best_at_space` >= ?") } + + func testReturning() throws { + let db = TestDatabase() + + try db.insert(into: "planets") + .columns("name") + .values("Jupiter") + .returning("id", "name") + .run().wait() + XCTAssertEqual(db.results[0], "INSERT INTO `planets` (`name`) VALUES (?) RETURNING `id`, `name`") + + try db.update("planets") + .set("name", to: "Jupiter") + .returning(SQLColumn("name", table: "planets")) + .run().wait() + XCTAssertEqual(db.results[1], "UPDATE `planets` SET `name` = ? RETURNING `planets`.`name`") + + try db.delete(from: "planets") + .returning("*") + .run().wait() + XCTAssertEqual(db.results[2], "DELETE FROM `planets` RETURNING *") + } } // MARK: Table Creation diff --git a/Tests/SQLKitTests/Utilities.swift b/Tests/SQLKitTests/Utilities.swift index b8979df3..2253e124 100644 --- a/Tests/SQLKitTests/Utilities.swift +++ b/Tests/SQLKitTests/Utilities.swift @@ -79,6 +79,8 @@ struct GenericDialect: SQLDialect { var supportsIfExists: Bool = true + var supportsReturning: Bool = true + var identifierQuote: SQLExpression { return SQLRaw("`") }