extract common classes for server request decompressor (#60)
* extract common classes for server request decompressor

* review fix: make fields private and make state part of the handler

* review fixes

* review fix: reserve capacity before inflating
artemredkin authored and Lukasa committed Oct 8, 2019
1 parent 863c6b5 commit 16fbdf3
Showing 3 changed files with 164 additions and 130 deletions.
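For orientation, here is a minimal sketch of how the extracted response decompressor is typically installed on the client side. The bootstrap wiring and the ratio value are illustrative assumptions, not part of this commit:

```swift
import NIO
import NIOHTTP1
import NIOHTTPCompression

// Place the response decompressor after the HTTP client handlers so that
// inbound body parts are inflated before later handlers see them.
let group = MultiThreadedEventLoopGroup(numberOfThreads: 1)
let bootstrap = ClientBootstrap(group: group)
    .channelInitializer { channel in
        channel.pipeline.addHTTPClientHandlers().flatMap { _ in
            // Example policy: allow the body to inflate to at most 10x its compressed size.
            channel.pipeline.addHandler(NIOHTTPResponseDecompressor(limit: .ratio(10)))
        }
    }
// A call such as bootstrap.connect(host:port:) would follow in real client code.
```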
149 changes: 149 additions & 0 deletions Sources/NIOHTTPCompression/HTTPDecompression.swift
@@ -0,0 +1,149 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2019 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//

import CNIOExtrasZlib
import NIO

public enum NIOHTTPDecompression {
    /// Specifies how to limit decompression inflation.
    public struct DecompressionLimit {
        private enum Limit {
            case none
            case size(Int)
            case ratio(Int)
        }

        private var limit: Limit

        /// No limit will be set.
        public static let none = DecompressionLimit(limit: .none)
        /// Limit will be set on the request body size.
        public static func size(_ value: Int) -> DecompressionLimit { return DecompressionLimit(limit: .size(value)) }
        /// Limit will be set on a ratio between compressed body size and decompressed result.
        public static func ratio(_ value: Int) -> DecompressionLimit { return DecompressionLimit(limit: .ratio(value)) }

        func exceeded(compressed: Int, decompressed: Int) -> Bool {
            switch self.limit {
            case .none:
                return false
            case .size(let allowed):
                return compressed > allowed
            case .ratio(let ratio):
                return decompressed > compressed * ratio
            }
        }
    }

    public enum DecompressionError: Error {
        case limit
        case inflationError(Int)
        case initializationError(Int)
    }

    enum CompressionAlgorithm: String {
        case gzip
        case deflate

        init?(header: String?) {
            switch header {
            case .some("gzip"):
                self = .gzip
            case .some("deflate"):
                self = .deflate
            default:
                return nil
            }
        }

        var window: CInt {
            switch self {
            case .deflate:
                return 15
            case .gzip:
                return 15 + 16
            }
        }
    }

    struct Decompressor {
        private let limit: NIOHTTPDecompression.DecompressionLimit
        private var stream = z_stream()
        private var inflated = 0

        init(limit: NIOHTTPDecompression.DecompressionLimit) {
            self.limit = limit
        }

        mutating func decompress(part: inout ByteBuffer, buffer: inout ByteBuffer, originalLength: Int) throws {
            buffer.reserveCapacity(part.readableBytes * 2)

            self.inflated += try self.stream.inflatePart(input: &part, output: &buffer)

            if self.limit.exceeded(compressed: originalLength, decompressed: self.inflated) {
                throw NIOHTTPDecompression.DecompressionError.limit
            }
        }

        mutating func initializeDecoder(encoding: NIOHTTPDecompression.CompressionAlgorithm, length: Int) throws {
            self.stream.zalloc = nil
            self.stream.zfree = nil
            self.stream.opaque = nil

            let rc = CNIOExtrasZlib_inflateInit2(&self.stream, encoding.window)
            guard rc == Z_OK else {
                throw NIOHTTPDecompression.DecompressionError.initializationError(Int(rc))
            }
        }

        mutating func deinitializeDecoder() {
            inflateEnd(&self.stream)
        }
    }
}

extension z_stream {
    mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws -> Int {
        var written = 0
        try input.readWithUnsafeMutableReadableBytes { pointer in
            self.avail_in = UInt32(pointer.count)
            self.next_in = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)

            defer {
                self.avail_in = 0
                self.next_in = nil
                self.avail_out = 0
                self.next_out = nil
            }

            written += try self.inflatePart(to: &output)

            return pointer.count - Int(self.avail_in)
        }
        return written
    }

    private mutating func inflatePart(to buffer: inout ByteBuffer) throws -> Int {
        return try buffer.writeWithUnsafeMutableBytes { pointer in
            self.avail_out = UInt32(pointer.count)
            self.next_out = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)

            let rc = inflate(&self, Z_NO_FLUSH)
            guard rc == Z_OK || rc == Z_STREAM_END else {
                throw NIOHTTPDecompression.DecompressionError.inflationError(Int(rc))
            }

            return pointer.count - Int(self.avail_out)
        }
    }
}
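The three limit policies above map directly onto the `exceeded(compressed:decompressed:)` check: `.none` never trips, `.size` compares the compressed body length reported by Content-Length against the allowance, and `.ratio` caps how far the inflated output may grow relative to the compressed input. A small illustrative sketch (the values are arbitrary examples, not defaults from this commit):

```swift
import NIOHTTPCompression

// With .ratio(10), a 1 KiB compressed body may inflate to at most 10 KiB
// before Decompressor.decompress throws DecompressionError.limit.
let unlimited = NIOHTTPDecompression.DecompressionLimit.none
let bySize    = NIOHTTPDecompression.DecompressionLimit.size(1024 * 1024) // compressed body up to 1 MiB
let byRatio   = NIOHTTPDecompression.DecompressionLimit.ratio(10)         // up to 10x expansion
```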
139 changes: 12 additions & 127 deletions Sources/NIOHTTPCompression/HTTPResponseDecompressor.swift
@@ -12,7 +12,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-import CNIOExtrasZlib
 import NIO
 import NIOHTTP1
 
@@ -22,78 +21,16 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC
     public typealias OutboundIn = HTTPClientRequestPart
     public typealias OutboundOut = HTTPClientRequestPart
 
-    /// Specifies how to limit decompression inflation.
-    public struct DecompressionLimit {
-        enum Limit {
-            case none
-            case size(Int)
-            case ratio(Int)
-        }
-
-        var limit: Limit
-
-        /// No limit will be set.
-        public static let none = DecompressionLimit(limit: .none)
-        /// Limit will be set on the request body size.
-        public static func size(_ value: Int) -> DecompressionLimit { return DecompressionLimit(limit: .size(value)) }
-        /// Limit will be set on a ratio between compressed body size and decompressed result.
-        public static func ratio(_ value: Int) -> DecompressionLimit { return DecompressionLimit(limit: .ratio(value)) }
-
-        func exceeded(compressed: Int, decompressed: Int) -> Bool {
-            switch self.limit {
-            case .none:
-                return false
-            case .size(let allowed):
-                return compressed > allowed
-            case .ratio(let ratio):
-                return decompressed > compressed * ratio
-            }
-        }
-    }
-
-    public enum DecompressionError: Error {
-        case limit
-        case inflationError(Int)
-        case initializationError(Int)
-    }
-
-    private enum CompressionAlgorithm: String {
-        case gzip
-        case deflate
-
-        init?(header: String?) {
-            switch header {
-            case .some("gzip"):
-                self = .gzip
-            case .some("deflate"):
-                self = .deflate
-            default:
-                return nil
-            }
-        }
-
-        var window: Int32 {
-            switch self {
-            case .deflate:
-                return 15
-            case .gzip:
-                return 15 + 16
-            }
-        }
-    }
-
     private enum State {
         case empty
-        case compressed(CompressionAlgorithm, Int)
+        case compressed(NIOHTTPDecompression.CompressionAlgorithm, Int)
     }
 
-    private let limit: DecompressionLimit
     private var state = State.empty
-    private var stream = z_stream()
-    private var inflated = 0
+    private var decompressor: NIOHTTPDecompression.Decompressor
 
-    public init(limit: DecompressionLimit) {
-        self.limit = limit
+    public init(limit: NIOHTTPDecompression.DecompressionLimit) {
+        self.decompressor = NIOHTTPDecompression.Decompressor(limit: limit)
     }
 
     public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {

public func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
@@ -115,13 +52,14 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC
         switch self.unwrapInboundIn(data) {
         case .head(let head):
             let contentType = head.headers[canonicalForm: "Content-Encoding"].first?.lowercased()
-            let algorithm = CompressionAlgorithm(header: contentType)
+            let algorithm = NIOHTTPDecompression.CompressionAlgorithm(header: contentType)
 
             let length = head.headers[canonicalForm: "Content-Length"].first.flatMap { Int($0) }
 
             if let algorithm = algorithm, let length = length {
                 do {
-                    try self.initializeDecoder(encoding: algorithm, length: length)
+                    self.state = .compressed(algorithm, length)
+                    try self.decompressor.initializeDecoder(encoding: algorithm, length: length)
                 } catch {
                     context.fireErrorCaught(error)
                     return
@@ -133,80 +71,27 @@ public final class NIOHTTPResponseDecompressor: ChannelDuplexHandler, RemovableC
             switch self.state {
             case .compressed(_, let originalLength):
                 while part.readableBytes > 0 {
+                    var buffer = context.channel.allocator.buffer(capacity: 16384)
                     do {
-                        var buffer = context.channel.allocator.buffer(capacity: 16384)
-                        try self.stream.inflatePart(input: &part, output: &buffer)
-                        self.inflated += buffer.readableBytes
-
-                        if self.limit.exceeded(compressed: originalLength, decompressed: self.inflated) {
-                            context.fireErrorCaught(DecompressionError.limit)
-                            return
-                        }
-
-                        context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
+                        try self.decompressor.decompress(part: &part, buffer: &buffer, originalLength: originalLength)
                     } catch {
                         context.fireErrorCaught(error)
                         return
                     }
+
+                    context.fireChannelRead(self.wrapInboundOut(.body(buffer)))
                 }
             default:
                 context.fireChannelRead(data)
             }
         case .end:
             switch self.state {
             case .compressed:
-                inflateEnd(&self.stream)
+                self.decompressor.deinitializeDecoder()
             default:
                 break
             }
             context.fireChannelRead(data)
         }
     }
-
-    private func initializeDecoder(encoding: CompressionAlgorithm, length: Int) throws {
-        self.state = .compressed(encoding, length)
-
-        self.stream.zalloc = nil
-        self.stream.zfree = nil
-        self.stream.opaque = nil
-
-        let rc = CNIOExtrasZlib_inflateInit2(&self.stream, encoding.window)
-        guard rc == Z_OK else {
-            throw DecompressionError.initializationError(Int(rc))
-        }
-    }
 }
-
-extension z_stream {
-    mutating func inflatePart(input: inout ByteBuffer, output: inout ByteBuffer) throws {
-        try input.readWithUnsafeMutableReadableBytes { pointer in
-            self.avail_in = UInt32(pointer.count)
-            self.next_in = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)
-
-            defer {
-                self.avail_in = 0
-                self.next_in = nil
-                self.avail_out = 0
-                self.next_out = nil
-            }
-
-            try self.inflatePart(to: &output)
-
-            return pointer.count - Int(self.avail_in)
-        }
-    }
-
-    private mutating func inflatePart(to buffer: inout ByteBuffer) throws {
-        try buffer.writeWithUnsafeMutableBytes { pointer in
-            self.avail_out = UInt32(pointer.count)
-            self.next_out = CNIOExtrasZlib_voidPtr_to_BytefPtr(pointer.baseAddress!)
-
-            let rc = inflate(&self, Z_NO_FLUSH)
-            guard rc == Z_OK || rc == Z_STREAM_END else {
-                throw NIOHTTPResponseDecompressor.DecompressionError.inflationError(Int(rc))
-            }
-
-            return pointer.count - Int(self.avail_out)
-        }
-    }
-}
@@ -40,7 +40,7 @@ class HTTPResponseDecompressorTest: XCTestCase {
         let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17])
         do {
            try channel.writeInbound(HTTPClientResponsePart.body(body))
-        } catch let error as NIOHTTPResponseDecompressor.DecompressionError {
+        } catch let error as NIOHTTPDecompression.DecompressionError {
             switch error {
             case .limit:
                 // ok
@@ -63,7 +63,7 @@ class HTTPResponseDecompressorTest: XCTestCase {
         let body = ByteBuffer.of(bytes: [120, 156, 75, 76, 28, 5, 200, 0, 0, 248, 66, 103, 17])
         do {
            try channel.writeInbound(HTTPClientResponsePart.body(body))
-        } catch let error as NIOHTTPResponseDecompressor.DecompressionError {
+        } catch let error as NIOHTTPDecompression.DecompressionError {
             switch error {
             case .limit:
                 // ok
@@ -100,7 +100,7 @@ class HTTPResponseDecompressorTest: XCTestCase {
 
         do {
            try channel.writeInbound(HTTPClientResponsePart.body(compressed))
-        } catch let error as NIOHTTPResponseDecompressor.DecompressionError {
+        } catch let error as NIOHTTPDecompression.DecompressionError {
             switch error {
             case .limit:
                 // ok
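The test hunks above only show the changed catch clauses; the handler under test is normally driven through an `EmbeddedChannel`. The setup below is an illustrative reconstruction under that assumption (the real fixture code lies outside the lines shown) and is not part of this diff:

```swift
import NIO
import NIOHTTP1
import NIOHTTPCompression

func makeDecompressingTestChannel() throws -> EmbeddedChannel {
    // A strict ratio makes the small deflate payloads above trip DecompressionError.limit.
    let channel = EmbeddedChannel(handler: NIOHTTPResponseDecompressor(limit: .ratio(1)))

    // The handler only starts inflating after it has seen a head with a recognised
    // Content-Encoding and a Content-Length it can check the limit against.
    var head = HTTPResponseHead(version: HTTPVersion(major: 1, minor: 1), status: .ok)
    head.headers.add(name: "Content-Encoding", value: "deflate")
    head.headers.add(name: "Content-Length", value: "13")
    try channel.writeInbound(HTTPClientResponsePart.head(head))
    return channel
}
```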
