From b841d475379a34b81d3130fde0de1d69a6819180 Mon Sep 17 00:00:00 2001 From: Dimitri Bouniol Date: Tue, 9 Jul 2024 03:44:34 -0700 Subject: [PATCH 1/3] Added support for conditional response compression --- .../HTTPResponseCompressor.swift | 53 ++- .../HTTPResponseCompressorTest.swift | 301 +++++++++++++++++- 2 files changed, 334 insertions(+), 20 deletions(-) diff --git a/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift b/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift index 7b251792..7477723e 100644 --- a/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift @@ -57,6 +57,13 @@ private func qValueFromHeader(_ text: S) -> Float { /// ahead-of-time instead of dynamically, could be a waste of CPU time and latency for relatively minimal /// benefit. This channel handler should be present in the pipeline only for dynamically-generated and /// highly-compressible content, which will see the biggest benefits from streaming compression. +/// +/// The compressor optionally accepts a predicate to help it determine on a per-request basis if compression +/// should be used, even if the client requests it for the request. This could be used to conditionally and statelessly +/// enable compression based on resource types, or by emitting and checking for marker headers as needed. +/// Since the predicate is always called, it can also be used to clean up those marker headers if compression was +/// not actually supported for any reason (ie. the client didn't provide compatible `Accept` headers, or the +/// response was missing a body due to a special status code being used) public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChannelHandler { /// This class accepts `HTTPServerRequestPart` inbound public typealias InboundIn = HTTPServerRequestPart @@ -66,6 +73,18 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne public typealias OutboundIn = HTTPServerResponsePart /// This class emits `HTTPServerResponsePart` outbound. public typealias OutboundOut = HTTPServerResponsePart + + /// A closure that accepts a response header, optionally modifies it, and returns `true` if the response it belongs to should be compressed. + /// + /// - Parameter responseHeaders: The headers that will be used for the response. These can be modified as needed at this stage, to clean up any marker headers used to statelessly determine if compression should occur, and the new headers will be used when writing the response. Compression headers are not yet provided and should not be set; ``HTTPResponseCompressor`` will set them accordingly based on the result of this predicate. + /// - Parameter isCompressionSupported: Set to `true` if the client requested compatible compression, and if the HTTP response supports it, otherwise `false`. + /// - Returns: Return `true` if the compressor should proceed to compress the response, or `false` if the response should not be compressed. + /// + /// - Note: Returning `true` when compression is not supported will not enable compression, and the modified headers will always be used. + public typealias ResponseCompressionPredicate = ( + _ responseHeaders: inout HTTPResponseHead, + _ isCompressionSupported: Bool + ) -> Bool /// Errors which can occur when compressing public enum CompressionError: Error { @@ -84,11 +103,21 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne private var pendingWritePromise: EventLoopPromise! private let initialByteBufferCapacity: Int + private let responseCompressionPredicate: ResponseCompressionPredicate? - /// Initialise a ``HTTPResponseCompressor`` + /// Initialize a ``HTTPResponseCompressor``. + /// - Parameter initialByteBufferCapacity: Initial size of buffer to allocate when hander is first added. + public convenience init(initialByteBufferCapacity: Int = 1024) { + // TODO: This version is kept around for backwards compatibility and should be merged with the signature below in the next major version: https://github.com/apple/swift-nio-extras/issues/226 + self.init(initialByteBufferCapacity: initialByteBufferCapacity, responseCompressionPredicate: nil) + } + + /// Initialize a ``HTTPResponseCompressor``. /// - Parameter initialByteBufferCapacity: Initial size of buffer to allocate when hander is first added. - public init(initialByteBufferCapacity: Int = 1024) { + /// - Parameter responseCompressionPredicate: The predicate used to determine if the response should be compressed or not based on its headers. Defaults to `nil`, which will compress every response this handler sees. This predicate is always called whether the client supports compression for this response or not, so it can be used to clean up any marker headers you may use to determine if compression should be performed or not. Please see ``ResponseCompressionPredicate`` for more details. + public init(initialByteBufferCapacity: Int = 1024, responseCompressionPredicate: ResponseCompressionPredicate? = nil) { self.initialByteBufferCapacity = initialByteBufferCapacity + self.responseCompressionPredicate = responseCompressionPredicate self.compressor = NIOCompression.Compressor() } @@ -118,17 +147,28 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne let httpData = unwrapOutboundIn(data) switch httpData { case .head(var responseHead): - guard let algorithm = compressionAlgorithm(), responseHead.status.mayHaveResponseBody else { + /// Grab the algorithm to use from the bottom of the accept queue, which will help determine if we support compression for this response or not. + let algorithm = compressionAlgorithm() + let requestSupportsCompression = algorithm != nil && responseHead.status.mayHaveResponseBody + + /// If a predicate was set, ask it if we should compress when compression is supported, and give the predicate a chance to clean up any marker headers that may have been set even if compression were not supported. + let predicateSupportsCompression = responseCompressionPredicate?(&responseHead, requestSupportsCompression) ?? true + + /// Make sure that compression should proceed, otherwise stop here and supply the response headers before configuring the compressor. + guard let algorithm, requestSupportsCompression, predicateSupportsCompression else { context.write(wrapOutboundOut(.head(responseHead)), promise: promise) return } - // Previous handlers in the pipeline might have already set this header even though - // they should not as it is compressor responsibility to decide what encoding to use + + /// Previous handlers in the pipeline might have already set this header even though they should not have as it is compressor responsibility to decide what encoding to use. responseHead.headers.replaceOrAdd(name: "Content-Encoding", value: algorithm.description) + + /// Initialize the compressor and write the header data, which marks the compressor as "active" allowing the `.body` and `.end` cases to properly compress the response rather than passing it as is. compressor.initialize(encoding: algorithm) pendingResponse.bufferResponseHead(responseHead) pendingWritePromise.futureResult.cascade(to: promise) case .body(let body): + /// We already determined if compression should occur based on the `.head` case above, so here we simply need to check if the compressor is active or not to determine if we should compress the body chunks or stream them as is. if compressor.isActive { pendingResponse.bufferBodyPart(body) pendingWritePromise.futureResult.cascade(to: promise) @@ -136,13 +176,12 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne context.write(data, promise: promise) } case .end: - // This compress is not done in flush because we need to be done with the - // compressor now. guard compressor.isActive else { context.write(data, promise: promise) return } + /// Compress any trailers and finalize the response. Note that this compression stage is not done in `flush()` because we need to clean up the compressor state to be ready for the next response that can come in on the same handler. pendingResponse.bufferResponseEnd(httpData) pendingWritePromise.futureResult.cascade(to: promise) emitPendingWrites(context: context) diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift index 44a2ef08..78810fb1 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift @@ -234,10 +234,18 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(expectedResponse, outputBuffer) } - private func assertDeflatedResponse(channel: EmbeddedChannel, writeStrategy: WriteStrategy = .once) throws { + private func assertDeflatedResponse( + channel: EmbeddedChannel, + writeStrategy: WriteStrategy = .once, + responseHeaders: HTTPHeaders = [:], + assertHeaders: HTTPHeaders? = nil + ) throws { let bodySize = 2048 - let response = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), - status: .ok) + let response = HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok, + headers: responseHeaders + ) let body = [UInt8](repeating: 60, count: bodySize) var bodyBuffer = channel.allocator.buffer(capacity: bodySize) bodyBuffer.writeBytes(body) @@ -265,17 +273,28 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(compressedResponse.headers[canonicalForm: "transfer-encoding"], ["chunked"]) } + if let assertHeaders { + XCTAssertEqual(compressedResponse.headers, assertHeaders) + } + assertDecompressedResponseMatches(responseData: &compressedBody, expectedResponse: bodyBuffer, allocator: channel.allocator, decompressor: z_stream.decompressDeflate) } - private func assertGzippedResponse(channel: EmbeddedChannel, writeStrategy: WriteStrategy = .once, additionalHeaders: HTTPHeaders = HTTPHeaders()) throws { + private func assertGzippedResponse( + channel: EmbeddedChannel, + writeStrategy: WriteStrategy = .once, + responseHeaders: HTTPHeaders = [:], + assertHeaders: HTTPHeaders? = nil + ) throws { let bodySize = 2048 - var response = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), - status: .ok) - response.headers = additionalHeaders + let response = HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok, + headers: responseHeaders + ) let body = [UInt8](repeating: 60, count: bodySize) var bodyBuffer = channel.allocator.buffer(capacity: bodySize) bodyBuffer.writeBytes(body) @@ -303,16 +322,28 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(compressedResponse.headers[canonicalForm: "transfer-encoding"], ["chunked"]) } + if let assertHeaders { + XCTAssertEqual(compressedResponse.headers, assertHeaders) + } + assertDecompressedResponseMatches(responseData: &compressedBody, expectedResponse: bodyBuffer, allocator: channel.allocator, decompressor: z_stream.decompressGzip) } - private func assertUncompressedResponse(channel: EmbeddedChannel, writeStrategy: WriteStrategy = .once) throws { + private func assertUncompressedResponse( + channel: EmbeddedChannel, + writeStrategy: WriteStrategy = .once, + responseHeaders: HTTPHeaders = [:], + assertHeaders: HTTPHeaders? = nil + ) throws { let bodySize = 2048 - let response = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), - status: .ok) + let response = HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok, + headers: responseHeaders + ) let body = [UInt8](repeating: 60, count: bodySize) var bodyBuffer = channel.allocator.buffer(capacity: bodySize) bodyBuffer.writeBytes(body) @@ -330,14 +361,19 @@ class HTTPResponseCompressorTest: XCTestCase { var compressedChunks = data.1 let uncompressedBody = compressedChunks[0].merge(compressedChunks[1...]) XCTAssertEqual(compressedResponse.headers[canonicalForm: "content-encoding"], []) + if let assertHeaders { + XCTAssertEqual(compressedResponse.headers, assertHeaders) + } XCTAssertEqual(uncompressedBody.readableBytes, 2048) XCTAssertEqual(uncompressedBody, bodyBuffer) } - private func compressionChannel() throws -> EmbeddedChannel { + private func compressionChannel( + compressor: HTTPResponseCompressor = HTTPResponseCompressor() + ) throws -> EmbeddedChannel { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPResponseEncoder(), name: "encoder").wait()) - XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPResponseCompressor(), name: "compressor").wait()) + XCTAssertNoThrow(try channel.pipeline.addHandler(compressor, name: "compressor").wait()) return channel } @@ -350,6 +386,17 @@ class HTTPResponseCompressorTest: XCTestCase { try sendRequest(acceptEncoding: "deflate", channel: channel) try assertDeflatedResponse(channel: channel) } + + func testExplicitInitialByteBufferCapacity() throws { + /// This test it to make sure there is no ambiguity choosing an initializer. + let channel = try compressionChannel(compressor: HTTPResponseCompressor(initialByteBufferCapacity: 2048)) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + try assertDeflatedResponse(channel: channel) + } func testCanCompressSimpleBodiesGzip() throws { let channel = try compressionChannel() @@ -504,7 +551,7 @@ class HTTPResponseCompressorTest: XCTestCase { } try sendRequest(acceptEncoding: "deflate;q=2.2, gzip;q=0.3", channel: channel) - try assertGzippedResponse(channel: channel, additionalHeaders: HTTPHeaders([("Content-Encoding", "deflate")])) + try assertGzippedResponse(channel: channel, responseHeaders: HTTPHeaders([("Content-Encoding", "deflate")])) } func testRemovingHandlerFailsPendingWrites() throws { @@ -688,6 +735,234 @@ class HTTPResponseCompressorTest: XCTestCase { } } } + + func testConditionalCompressionEnabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, true) + return true + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + try assertDeflatedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json"], + assertHeaders: [ + "Content-Type" : "json", + "Content-Encoding" : "deflate", + "Content-Length" : "23", + ] + ) + + waitForExpectations(timeout: 0) + } + + func testUnsupportedRequestConditionalCompressionEnabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, false) + return true + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: nil, channel: channel) + try assertUncompressedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json"], + assertHeaders: [ + "Content-Type" : "json", + "transfer-encoding" : "chunked", + ] + ) + + waitForExpectations(timeout: 0) + } + + func testUnsupportedStatusConditionalCompressionEnabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.status, .notModified) + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, false) + return true + } + + let channel = EmbeddedChannel() + XCTAssertNoThrow(try channel.pipeline.addHandler(compressor).wait()) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + + let head = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .notModified, + headers: ["Content-Type" : "json"] + ) + try channel.writeOutbound(HTTPServerResponsePart.head(head)) + try channel.writeOutbound(HTTPServerResponsePart.end(nil)) + + while let part = try channel.readOutbound(as: HTTPServerResponsePart.self) { + switch part { + case .head(let head): + XCTAssertEqual(head.headers[canonicalForm: "content-encoding"], []) + case .body: + XCTFail("Unexpected body") + case .end: break + } + } + + waitForExpectations(timeout: 0) + } + + func testConditionalCompressionDisabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, true) + return false + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + try assertUncompressedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json"], + assertHeaders: [ + "Content-Type" : "json", + "transfer-encoding" : "chunked", + ] + ) + + waitForExpectations(timeout: 0) + } + + func testUnsupportedRequestConditionalCompressionDisabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, false) + return false + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: nil, channel: channel) + try assertUncompressedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json"], + assertHeaders: [ + "Content-Type" : "json", + "transfer-encoding" : "chunked", + ] + ) + + waitForExpectations(timeout: 0) + } + + func testUnsupportedStatusConditionalCompressionDisabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.status, .notModified) + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, false) + return false + } + + let channel = EmbeddedChannel() + XCTAssertNoThrow(try channel.pipeline.addHandler(compressor).wait()) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + + let head = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .notModified, + headers: ["Content-Type" : "json"] + ) + try channel.writeOutbound(HTTPServerResponsePart.head(head)) + try channel.writeOutbound(HTTPServerResponsePart.end(nil)) + + while let part = try channel.readOutbound(as: HTTPServerResponsePart.self) { + switch part { + case .head(let head): + XCTAssertEqual(head.headers[canonicalForm: "content-encoding"], []) + case .body: + XCTFail("Unexpected body") + case .end: break + } + } + + waitForExpectations(timeout: 0) + } + + func testConditionalCompressionModifiedHeaders() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + predicateWasCalled.expectedFulfillmentCount = 2 + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + let isEnabled = responseHeaders.headers[canonicalForm: "x-compression"].first == "enable" + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json", "X-Compression" : isEnabled ? "enable" : "disable"]) + responseHeaders.headers.remove(name: "X-Compression") + XCTAssertEqual(isCompressionSupported, true) + return isEnabled + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + try assertDeflatedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json", "X-Compression" : "enable"], + assertHeaders: [ + "Content-Type" : "json", + "Content-Encoding" : "deflate", + "Content-Length" : "23", + ] + ) + + try sendRequest(acceptEncoding: "deflate", channel: channel) + try assertUncompressedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json", "X-Compression" : "disable"], + assertHeaders: [ + "Content-Type" : "json", + "transfer-encoding" : "chunked", + ] + ) + + waitForExpectations(timeout: 0) + } } extension EventLoopFuture { From 999fd2cc2a57362bce99509a816d2cd76bb68715 Mon Sep 17 00:00:00 2001 From: Dimitri Bouniol Date: Wed, 17 Jul 2024 13:25:41 -0700 Subject: [PATCH 2/3] Updated expectations to be fulfilled via a defer --- .../HTTPResponseCompressorTest.swift | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift index 78810fb1..343142cd 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift @@ -739,7 +739,7 @@ class HTTPResponseCompressorTest: XCTestCase { func testConditionalCompressionEnabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in - predicateWasCalled.fulfill() + defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, true) return true @@ -767,7 +767,7 @@ class HTTPResponseCompressorTest: XCTestCase { func testUnsupportedRequestConditionalCompressionEnabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in - predicateWasCalled.fulfill() + defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, false) return true @@ -794,7 +794,7 @@ class HTTPResponseCompressorTest: XCTestCase { func testUnsupportedStatusConditionalCompressionEnabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in - predicateWasCalled.fulfill() + defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.status, .notModified) XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, false) @@ -833,7 +833,7 @@ class HTTPResponseCompressorTest: XCTestCase { func testConditionalCompressionDisabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in - predicateWasCalled.fulfill() + defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, true) return false @@ -860,7 +860,7 @@ class HTTPResponseCompressorTest: XCTestCase { func testUnsupportedRequestConditionalCompressionDisabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in - predicateWasCalled.fulfill() + defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, false) return false @@ -887,7 +887,7 @@ class HTTPResponseCompressorTest: XCTestCase { func testUnsupportedStatusConditionalCompressionDisabled() throws { let predicateWasCalled = expectation(description: "Predicate was called") let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in - predicateWasCalled.fulfill() + defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.status, .notModified) XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, false) @@ -927,7 +927,7 @@ class HTTPResponseCompressorTest: XCTestCase { let predicateWasCalled = expectation(description: "Predicate was called") predicateWasCalled.expectedFulfillmentCount = 2 let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in - predicateWasCalled.fulfill() + defer { predicateWasCalled.fulfill() } let isEnabled = responseHeaders.headers[canonicalForm: "x-compression"].first == "enable" XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json", "X-Compression" : isEnabled ? "enable" : "disable"]) responseHeaders.headers.remove(name: "X-Compression") From eab13037326be29e9f06213fe34d8fb281e168c5 Mon Sep 17 00:00:00 2001 From: Dimitri Bouniol Date: Wed, 17 Jul 2024 13:23:43 -0700 Subject: [PATCH 3/3] Updated compression response predicate to return an intent enum rather than a boolean --- .../HTTPResponseCompressor.swift | 34 ++++++++++++++++--- .../HTTPResponseCompressorTest.swift | 14 ++++---- 2 files changed, 36 insertions(+), 12 deletions(-) diff --git a/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift b/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift index 7477723e..ebdd9e1a 100644 --- a/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift @@ -78,13 +78,37 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne /// /// - Parameter responseHeaders: The headers that will be used for the response. These can be modified as needed at this stage, to clean up any marker headers used to statelessly determine if compression should occur, and the new headers will be used when writing the response. Compression headers are not yet provided and should not be set; ``HTTPResponseCompressor`` will set them accordingly based on the result of this predicate. /// - Parameter isCompressionSupported: Set to `true` if the client requested compatible compression, and if the HTTP response supports it, otherwise `false`. - /// - Returns: Return `true` if the compressor should proceed to compress the response, or `false` if the response should not be compressed. + /// - Returns: Return ``CompressionIntent/compressIfPossible`` if the compressor should proceed to compress the response, or ``CompressionIntent/doNotCompress`` if the response should not be compressed. /// - /// - Note: Returning `true` when compression is not supported will not enable compression, and the modified headers will always be used. + /// - Note: Returning ``CompressionIntent/compressIfPossible`` is only a suggestion — when compression is not supported, the response will be returned as is along with any modified headers. public typealias ResponseCompressionPredicate = ( _ responseHeaders: inout HTTPResponseHead, _ isCompressionSupported: Bool - ) -> Bool + ) -> CompressionIntent + + /// A signal a ``ResponseCompressionPredicate`` returns to indicate if it intends for compression to be used or not when supported by HTTP. + public struct CompressionIntent: Sendable, Hashable { + /// The internal type ``CompressionIntent`` uses. + enum RawValue { + /// The response should be compressed if supported by the HTTP protocol. + case compressIfPossible + /// The response should not be compressed even if supported by the HTTP protocol. + case doNotCompress + } + + /// The raw value of the intent. + let rawValue: RawValue + + /// Initialize the raw value with an internal intent. + init(_ rawValue: RawValue) { + self.rawValue = rawValue + } + + /// The response should be compressed if supported by the HTTP protocol. + public static let compressIfPossible = CompressionIntent(.compressIfPossible) + /// The response should not be compressed even if supported by the HTTP protocol. + public static let doNotCompress = CompressionIntent(.doNotCompress) + } /// Errors which can occur when compressing public enum CompressionError: Error { @@ -152,10 +176,10 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne let requestSupportsCompression = algorithm != nil && responseHead.status.mayHaveResponseBody /// If a predicate was set, ask it if we should compress when compression is supported, and give the predicate a chance to clean up any marker headers that may have been set even if compression were not supported. - let predicateSupportsCompression = responseCompressionPredicate?(&responseHead, requestSupportsCompression) ?? true + let predicateCompressionIntent = responseCompressionPredicate?(&responseHead, requestSupportsCompression) ?? .compressIfPossible /// Make sure that compression should proceed, otherwise stop here and supply the response headers before configuring the compressor. - guard let algorithm, requestSupportsCompression, predicateSupportsCompression else { + guard let algorithm, requestSupportsCompression, predicateCompressionIntent == .compressIfPossible else { context.write(wrapOutboundOut(.head(responseHead)), promise: promise) return } diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift index 343142cd..99a56730 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift @@ -742,7 +742,7 @@ class HTTPResponseCompressorTest: XCTestCase { defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, true) - return true + return .compressIfPossible } let channel = try compressionChannel(compressor: compressor) @@ -770,7 +770,7 @@ class HTTPResponseCompressorTest: XCTestCase { defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, false) - return true + return .compressIfPossible } let channel = try compressionChannel(compressor: compressor) @@ -798,7 +798,7 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(responseHeaders.status, .notModified) XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, false) - return true + return .compressIfPossible } let channel = EmbeddedChannel() @@ -836,7 +836,7 @@ class HTTPResponseCompressorTest: XCTestCase { defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, true) - return false + return .doNotCompress } let channel = try compressionChannel(compressor: compressor) @@ -863,7 +863,7 @@ class HTTPResponseCompressorTest: XCTestCase { defer { predicateWasCalled.fulfill() } XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, false) - return false + return .doNotCompress } let channel = try compressionChannel(compressor: compressor) @@ -891,7 +891,7 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(responseHeaders.status, .notModified) XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) XCTAssertEqual(isCompressionSupported, false) - return false + return .doNotCompress } let channel = EmbeddedChannel() @@ -932,7 +932,7 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json", "X-Compression" : isEnabled ? "enable" : "disable"]) responseHeaders.headers.remove(name: "X-Compression") XCTAssertEqual(isCompressionSupported, true) - return isEnabled + return isEnabled ? .compressIfPossible : .doNotCompress } let channel = try compressionChannel(compressor: compressor)