From 1972e7764cc8a524ace79bebfc7a12caab481e2d Mon Sep 17 00:00:00 2001 From: Dimitri Bouniol Date: Tue, 9 Jul 2024 03:44:34 -0700 Subject: [PATCH] Added support for conditional response compression --- .../HTTPResponseCompressor.swift | 55 +++- .../HTTPResponseCompressorTest.swift | 290 +++++++++++++++++- 2 files changed, 325 insertions(+), 20 deletions(-) diff --git a/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift b/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift index 7b251792..255840ad 100644 --- a/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift +++ b/Sources/NIOHTTPCompression/HTTPResponseCompressor.swift @@ -57,6 +57,13 @@ private func qValueFromHeader(_ text: S) -> Float { /// ahead-of-time instead of dynamically, could be a waste of CPU time and latency for relatively minimal /// benefit. This channel handler should be present in the pipeline only for dynamically-generated and /// highly-compressible content, which will see the biggest benefits from streaming compression. +/// +/// The compressor optionally accepts a predicate to help it determine on a per-request basis if compression +/// should be used, even if the client requests it for the request. This could be used to conditionally and statelessly +/// enable compression based on resource types, or by emitting and checking for marker headers as needed. +/// Since the predicate is always called, it can also be used to clean up those marker headers if compression was +/// not actually supported for any reason (ie. the client didn't provide compatible `Accept` headers, or the +/// response was missing a body due to a special status code being used) public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChannelHandler { /// This class accepts `HTTPServerRequestPart` inbound public typealias InboundIn = HTTPServerRequestPart @@ -66,6 +73,18 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne public typealias OutboundIn = HTTPServerResponsePart /// This class emits `HTTPServerResponsePart` outbound. public typealias OutboundOut = HTTPServerResponsePart + + /// A closure that accepts a response header, optionally modifies it, and returns `true` if the response it belongs to should be compressed. + /// + /// - Parameter responseHeaders: The headers that will be used for the response. These can be modified as needed at this stage, to clean up any marker headers used to statelessly determine if compression should occur, and the new headers will be used when writing the response. Compression headers are not yet provided and should not be set; ``HTTPResponseCompressor`` will set them accordingly based on the result of this predicate. + /// - Parameter isCompressionSupported: Set to `true` if the client requested compatible compression, and if the HTTP response supports it, otherwise `false`. + /// - Returns: Return `true` if the compressor should proceed to compress the response, or `false` if the response should not be compressed. + /// + /// - Note: Returning `true` when compression is not supported will not enable compression, and the modified headers will always be used. + public typealias ResponseCompressionPredicate = ( + _ responseHeaders: inout HTTPResponseHead, + _ isCompressionSupported: Bool + ) -> Bool /// Errors which can occur when compressing public enum CompressionError: Error { @@ -84,11 +103,23 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne private var pendingWritePromise: EventLoopPromise! private let initialByteBufferCapacity: Int + private let responseCompressionPredicate: ResponseCompressionPredicate? - /// Initialise a ``HTTPResponseCompressor`` + /// Initialize a ``HTTPResponseCompressor``. + /// - Parameter initialByteBufferCapacity: Initial size of buffer to allocate when hander is first added. + @_disfavoredOverload + @available(*, deprecated, message: "Deprecated in favor of HTTPResponseCompressor(initialByteBufferCapacity:, responseCompressionPredicate:)") + public convenience init(initialByteBufferCapacity: Int = 1024) { + // TODO: This version is kept around for backwards compatibility and should be merged with the signature below in the next major version. + self.init(initialByteBufferCapacity: initialByteBufferCapacity, responseCompressionPredicate: nil) + } + + /// Initialize a ``HTTPResponseCompressor``. /// - Parameter initialByteBufferCapacity: Initial size of buffer to allocate when hander is first added. - public init(initialByteBufferCapacity: Int = 1024) { + /// - Parameter responseCompressionPredicate: The predicate used to determine if the response should be compressed or not based on its headers. Defaults to `nil`, which will compress every response this handler sees. This predicate is always called wether the client supports compression for this response or not, so it can be used to clean up any marker headers you may use to determine if compression should be performed or not. Please see ``ResponseCompressionPredicate`` for more details. + public init(initialByteBufferCapacity: Int = 1024, responseCompressionPredicate: ResponseCompressionPredicate? = nil) { self.initialByteBufferCapacity = initialByteBufferCapacity + self.responseCompressionPredicate = responseCompressionPredicate self.compressor = NIOCompression.Compressor() } @@ -118,17 +149,28 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne let httpData = unwrapOutboundIn(data) switch httpData { case .head(var responseHead): - guard let algorithm = compressionAlgorithm(), responseHead.status.mayHaveResponseBody else { + /// Grab the algorithm to use from the bottom of the accept queue, which will help determine if we support compression for this response or not. + let algorithm = compressionAlgorithm() + let requestSupportsCompression = algorithm != nil && responseHead.status.mayHaveResponseBody + + /// If a predicate was set, ask it if we should compress when compression is supported, and give the predicate a chance to clean up any marker headers that may have been set even if compression were not supported. + let predicateSupportsCompression = responseCompressionPredicate?(&responseHead, requestSupportsCompression) ?? true + + /// Make sure that compression should proceed, otherwise stop here and supply the response headers before configuring the compressor. + guard let algorithm, requestSupportsCompression, predicateSupportsCompression else { context.write(wrapOutboundOut(.head(responseHead)), promise: promise) return } - // Previous handlers in the pipeline might have already set this header even though - // they should not as it is compressor responsibility to decide what encoding to use + + /// Previous handlers in the pipeline might have already set this header even though they should not have as it is compressor responsibility to decide what encoding to use. responseHead.headers.replaceOrAdd(name: "Content-Encoding", value: algorithm.description) + + /// Initialize the compressor and write the header data, which marks the compressor as "active" allowing the `.body` and `.end` cases to properly compress the response rather than passing it as is. compressor.initialize(encoding: algorithm) pendingResponse.bufferResponseHead(responseHead) pendingWritePromise.futureResult.cascade(to: promise) case .body(let body): + /// We already determined if compression should occur based on the `.head` case above, so here we simply need to check if the compressor is active or not to determine if we should compress the body chunks or stream them as is. if compressor.isActive { pendingResponse.bufferBodyPart(body) pendingWritePromise.futureResult.cascade(to: promise) @@ -136,13 +178,12 @@ public final class HTTPResponseCompressor: ChannelDuplexHandler, RemovableChanne context.write(data, promise: promise) } case .end: - // This compress is not done in flush because we need to be done with the - // compressor now. guard compressor.isActive else { context.write(data, promise: promise) return } + /// Compress any trailers and finalize the response. Note that this compression stage is not done in `flush()` because we need to clean up the compressor state to be ready for the next response that can come in on the same handler. pendingResponse.bufferResponseEnd(httpData) pendingWritePromise.futureResult.cascade(to: promise) emitPendingWrites(context: context) diff --git a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift index 44a2ef08..ffbb3184 100644 --- a/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift +++ b/Tests/NIOHTTPCompressionTests/HTTPResponseCompressorTest.swift @@ -234,10 +234,18 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(expectedResponse, outputBuffer) } - private func assertDeflatedResponse(channel: EmbeddedChannel, writeStrategy: WriteStrategy = .once) throws { + private func assertDeflatedResponse( + channel: EmbeddedChannel, + writeStrategy: WriteStrategy = .once, + responseHeaders: HTTPHeaders = [:], + assertHeaders: HTTPHeaders? = nil + ) throws { let bodySize = 2048 - let response = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), - status: .ok) + let response = HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok, + headers: responseHeaders + ) let body = [UInt8](repeating: 60, count: bodySize) var bodyBuffer = channel.allocator.buffer(capacity: bodySize) bodyBuffer.writeBytes(body) @@ -265,17 +273,28 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(compressedResponse.headers[canonicalForm: "transfer-encoding"], ["chunked"]) } + if let assertHeaders { + XCTAssertEqual(compressedResponse.headers, assertHeaders) + } + assertDecompressedResponseMatches(responseData: &compressedBody, expectedResponse: bodyBuffer, allocator: channel.allocator, decompressor: z_stream.decompressDeflate) } - private func assertGzippedResponse(channel: EmbeddedChannel, writeStrategy: WriteStrategy = .once, additionalHeaders: HTTPHeaders = HTTPHeaders()) throws { + private func assertGzippedResponse( + channel: EmbeddedChannel, + writeStrategy: WriteStrategy = .once, + responseHeaders: HTTPHeaders = [:], + assertHeaders: HTTPHeaders? = nil + ) throws { let bodySize = 2048 - var response = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), - status: .ok) - response.headers = additionalHeaders + let response = HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok, + headers: responseHeaders + ) let body = [UInt8](repeating: 60, count: bodySize) var bodyBuffer = channel.allocator.buffer(capacity: bodySize) bodyBuffer.writeBytes(body) @@ -303,16 +322,28 @@ class HTTPResponseCompressorTest: XCTestCase { XCTAssertEqual(compressedResponse.headers[canonicalForm: "transfer-encoding"], ["chunked"]) } + if let assertHeaders { + XCTAssertEqual(compressedResponse.headers, assertHeaders) + } + assertDecompressedResponseMatches(responseData: &compressedBody, expectedResponse: bodyBuffer, allocator: channel.allocator, decompressor: z_stream.decompressGzip) } - private func assertUncompressedResponse(channel: EmbeddedChannel, writeStrategy: WriteStrategy = .once) throws { + private func assertUncompressedResponse( + channel: EmbeddedChannel, + writeStrategy: WriteStrategy = .once, + responseHeaders: HTTPHeaders = [:], + assertHeaders: HTTPHeaders? = nil + ) throws { let bodySize = 2048 - let response = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), - status: .ok) + let response = HTTPResponseHead( + version: HTTPVersion(major: 1, minor: 1), + status: .ok, + headers: responseHeaders + ) let body = [UInt8](repeating: 60, count: bodySize) var bodyBuffer = channel.allocator.buffer(capacity: bodySize) bodyBuffer.writeBytes(body) @@ -330,14 +361,19 @@ class HTTPResponseCompressorTest: XCTestCase { var compressedChunks = data.1 let uncompressedBody = compressedChunks[0].merge(compressedChunks[1...]) XCTAssertEqual(compressedResponse.headers[canonicalForm: "content-encoding"], []) + if let assertHeaders { + XCTAssertEqual(compressedResponse.headers, assertHeaders) + } XCTAssertEqual(uncompressedBody.readableBytes, 2048) XCTAssertEqual(uncompressedBody, bodyBuffer) } - private func compressionChannel() throws -> EmbeddedChannel { + private func compressionChannel( + compressor: HTTPResponseCompressor = HTTPResponseCompressor() + ) throws -> EmbeddedChannel { let channel = EmbeddedChannel() XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPResponseEncoder(), name: "encoder").wait()) - XCTAssertNoThrow(try channel.pipeline.addHandler(HTTPResponseCompressor(), name: "compressor").wait()) + XCTAssertNoThrow(try channel.pipeline.addHandler(compressor, name: "compressor").wait()) return channel } @@ -504,7 +540,7 @@ class HTTPResponseCompressorTest: XCTestCase { } try sendRequest(acceptEncoding: "deflate;q=2.2, gzip;q=0.3", channel: channel) - try assertGzippedResponse(channel: channel, additionalHeaders: HTTPHeaders([("Content-Encoding", "deflate")])) + try assertGzippedResponse(channel: channel, responseHeaders: HTTPHeaders([("Content-Encoding", "deflate")])) } func testRemovingHandlerFailsPendingWrites() throws { @@ -688,6 +724,234 @@ class HTTPResponseCompressorTest: XCTestCase { } } } + + func testConditionalCompressionEnabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, true) + return true + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + try assertDeflatedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json"], + assertHeaders: [ + "Content-Type" : "json", + "Content-Encoding" : "deflate", + "Content-Length" : "23", + ] + ) + + waitForExpectations(timeout: 0) + } + + func testUnsupportedRequestConditionalCompressionEnabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, false) + return true + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: nil, channel: channel) + try assertUncompressedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json"], + assertHeaders: [ + "Content-Type" : "json", + "transfer-encoding" : "chunked", + ] + ) + + waitForExpectations(timeout: 0) + } + + func testUnsupportedStatusConditionalCompressionEnabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.status, .notModified) + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, false) + return true + } + + let channel = EmbeddedChannel() + XCTAssertNoThrow(try channel.pipeline.addHandler(compressor).wait()) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + + let head = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .notModified, + headers: ["Content-Type" : "json"] + ) + try channel.writeOutbound(HTTPServerResponsePart.head(head)) + try channel.writeOutbound(HTTPServerResponsePart.end(nil)) + + while let part = try channel.readOutbound(as: HTTPServerResponsePart.self) { + switch part { + case .head(let head): + XCTAssertEqual(head.headers[canonicalForm: "content-encoding"], []) + case .body: + XCTFail("Unexpected body") + case .end: break + } + } + + waitForExpectations(timeout: 0) + } + + func testConditionalCompressionDisabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, true) + return false + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + try assertUncompressedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json"], + assertHeaders: [ + "Content-Type" : "json", + "transfer-encoding" : "chunked", + ] + ) + + waitForExpectations(timeout: 0) + } + + func testUnsupportedRequestConditionalCompressionDisabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, false) + return false + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: nil, channel: channel) + try assertUncompressedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json"], + assertHeaders: [ + "Content-Type" : "json", + "transfer-encoding" : "chunked", + ] + ) + + waitForExpectations(timeout: 0) + } + + func testUnsupportedStatusConditionalCompressionDisabled() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + XCTAssertEqual(responseHeaders.status, .notModified) + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json"]) + XCTAssertEqual(isCompressionSupported, false) + return false + } + + let channel = EmbeddedChannel() + XCTAssertNoThrow(try channel.pipeline.addHandler(compressor).wait()) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + + let head = HTTPResponseHead( + version: .init(major: 1, minor: 1), + status: .notModified, + headers: ["Content-Type" : "json"] + ) + try channel.writeOutbound(HTTPServerResponsePart.head(head)) + try channel.writeOutbound(HTTPServerResponsePart.end(nil)) + + while let part = try channel.readOutbound(as: HTTPServerResponsePart.self) { + switch part { + case .head(let head): + XCTAssertEqual(head.headers[canonicalForm: "content-encoding"], []) + case .body: + XCTFail("Unexpected body") + case .end: break + } + } + + waitForExpectations(timeout: 0) + } + + func testConditionalCompressionModifiedHeaders() throws { + let predicateWasCalled = expectation(description: "Predicate was called") + predicateWasCalled.expectedFulfillmentCount = 2 + let compressor = HTTPResponseCompressor { responseHeaders, isCompressionSupported in + predicateWasCalled.fulfill() + let isEnabled = responseHeaders.headers[canonicalForm: "x-compression"].first == "enable" + XCTAssertEqual(responseHeaders.headers, ["Content-Type" : "json", "X-Compression" : isEnabled ? "enable" : "disable"]) + responseHeaders.headers.remove(name: "X-Compression") + XCTAssertEqual(isCompressionSupported, true) + return isEnabled + } + + let channel = try compressionChannel(compressor: compressor) + defer { + XCTAssertNoThrow(try channel.finish()) + } + + try sendRequest(acceptEncoding: "deflate", channel: channel) + try assertDeflatedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json", "X-Compression" : "enable"], + assertHeaders: [ + "Content-Type" : "json", + "Content-Encoding" : "deflate", + "Content-Length" : "23", + ] + ) + + try sendRequest(acceptEncoding: "deflate", channel: channel) + try assertUncompressedResponse( + channel: channel, + responseHeaders: ["Content-Type" : "json", "X-Compression" : "disable"], + assertHeaders: [ + "Content-Type" : "json", + "transfer-encoding" : "chunked", + ] + ) + + waitForExpectations(timeout: 0) + } } extension EventLoopFuture {