From 3de0c884a4f223c4898235355795743f5f93dbaa Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 18 Nov 2020 10:07:55 +0200 Subject: [PATCH 1/3] Drop support for Python 2 --- .github/workflows/ci.yml | 2 -- CONTRIBUTING.md | 6 ++--- README.rst | 3 ++- bench/asv.conf.json | 2 +- docs/source/index.rst | 4 +++- fuzz/afl-server.py | 7 +----- h11/_connection.py | 2 +- h11/_events.py | 5 +---- h11/_headers.py | 2 +- h11/_readers.py | 16 ++++---------- h11/_receivebuffer.py | 8 +------ h11/_state.py | 2 +- h11/_util.py | 28 ++++------------------- h11/_writers.py | 32 +++++---------------------- h11/tests/test_against_stdlib_http.py | 18 +++------------ h11/tests/test_events.py | 6 ++--- h11/tests/test_util.py | 2 +- newsfragments/114.removal.rst | 2 ++ setup.cfg | 3 --- setup.py | 4 ++-- tox.ini | 4 +--- 21 files changed, 40 insertions(+), 118 deletions(-) create mode 100644 newsfragments/114.removal.rst diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 1125da4..c31669a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,11 +13,9 @@ jobs: max-parallel: 5 matrix: python-version: - - 2.7 - 3.6 - 3.7 - 3.8 - - pypy2 - pypy3 steps: diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index b9eccd9..45a1ebf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -70,10 +70,10 @@ other hand, the following are all very welcome: tox ``` - But note that: (1) this will print slightly misleading coverage + But note that: (1) this might print slightly misleading coverage statistics, because it only shows coverage for individual python - versions, and there are some lines that are only executed on python - 2 or only executed on python 3, and (2) the full test suite will + versions, and there might be some lines that are only executed on some + python versions or implementations, and (2) the full test suite will automatically get run when you submit a pull request, so you don't need to worry too much about tracking down a version of cpython 3.3 
or whatever just to run the tests. diff --git a/README.rst b/README.rst index 74bb182..f998b01 100644 --- a/README.rst +++ b/README.rst @@ -112,7 +112,8 @@ library. It has a test suite with 100.0% coverage for both statements and branches. -Currently it supports Python 3 (testing on 3.5-3.8), Python 2.7, and PyPy. +Currently it supports Python 3 (testing on 3.5-3.8) and PyPy 3. +The last Python 2-compatible version was h11 0.11.x. (Originally it had a Cython wrapper for `http-parser <https://github.com/nodejs/http-parser>`_ and a beautiful nested state machine implemented with ``yield from`` to postprocess the output. But diff --git a/bench/asv.conf.json b/bench/asv.conf.json index f65e4dd..0a07c42 100644 --- a/bench/asv.conf.json +++ b/bench/asv.conf.json @@ -36,7 +36,7 @@ // The Pythons you'd like to test against. If not provided, defaults // to the current version of Python used to run `asv`. - "pythons": ["2.7", "3.5", "pypy"], + "pythons": ["3.8", "pypy3"], // The matrix of dependencies to test. Each key is the name of a // package (in PyPI) and the values are version numbers. An empty diff --git a/docs/source/index.rst b/docs/source/index.rst index c08b76d..617638f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -44,7 +44,9 @@ whatever. But h11 makes it much easier to implement something like Vital statistics ---------------- -* Requirements: Python 2.7 or Python 3.5+ (PyPy works great) +* Requirements: Python 3.5+ (PyPy works great) + + The last Python 2-compatible version was h11 0.11.x. 
* Install: ``pip install h11`` diff --git a/fuzz/afl-server.py b/fuzz/afl-server.py index 450c68b..0ff1947 100644 --- a/fuzz/afl-server.py +++ b/fuzz/afl-server.py @@ -9,11 +9,6 @@ import h11 -if sys.version_info[0] >= 3: - in_file = sys.stdin.detach() -else: - in_file = sys.stdin - def process_all(c): while True: @@ -26,7 +21,7 @@ def process_all(c): afl.init() -data = in_file.read() +data = sys.stdin.detach().read() # one big chunk server1 = h11.Connection(h11.SERVER) diff --git a/h11/_connection.py b/h11/_connection.py index 410c4e9..b6f8760 100644 --- a/h11/_connection.py +++ b/h11/_connection.py @@ -109,7 +109,7 @@ def _body_framing(request_method, event): ################################################################ -class Connection(object): +class Connection: """An object encapsulating the state of an HTTP connection. Args: diff --git a/h11/_events.py b/h11/_events.py index c11d838..1827930 100644 --- a/h11/_events.py +++ b/h11/_events.py @@ -24,7 +24,7 @@ request_target_re = re.compile(request_target.encode("ascii")) -class _EventBundle(object): +class _EventBundle: _fields = [] _defaults = {} @@ -85,9 +85,6 @@ def __repr__(self): def __eq__(self, other): return self.__class__ == other.__class__ and self.__dict__ == other.__dict__ - def __ne__(self, other): - return not self.__eq__(other) - # This is an unhashable type. 
__hash__ = None diff --git a/h11/_headers.py b/h11/_headers.py index 5229ac4..7ed39bc 100644 --- a/h11/_headers.py +++ b/h11/_headers.py @@ -132,7 +132,7 @@ def normalize_and_validate(headers, _parsed=False): raw_name = name name = name.lower() if name == b"content-length": - lengths = set(length.strip() for length in value.split(b",")) + lengths = {length.strip() for length in value.split(b",")} if len(lengths) != 1: raise LocalProtocolError("conflicting Content-Length headers") value = lengths.pop() diff --git a/h11/_readers.py b/h11/_readers.py index cc86bff..75f00bc 100644 --- a/h11/_readers.py +++ b/h11/_readers.py @@ -54,13 +54,7 @@ def _obsolete_line_fold(lines): def _decode_header_lines(lines): for line in _obsolete_line_fold(lines): - # _obsolete_line_fold yields either bytearray or bytes objects. On - # Python 3, validate() takes either and returns matches as bytes. But - # on Python 2, validate can return matches as bytearrays, so we have - # to explicitly cast back. - matches = validate( - header_field_re, bytes(line), "illegal header line: {!r}", bytes(line) - ) + matches = validate(header_field_re, line, "illegal header line: {!r}", line) yield (matches["field_name"], matches["field_value"]) @@ -127,7 +121,7 @@ def read_eof(self): chunk_header_re = re.compile(chunk_header.encode("ascii")) -class ChunkedReader(object): +class ChunkedReader: def __init__(self): self._bytes_in_chunk = 0 # After reading a chunk, we have to throw away the trailing \r\n; if @@ -163,9 +157,7 @@ def __call__(self, buf): chunk_header, ) # XX FIXME: we discard chunk extensions. Does anyone care? - # We convert to bytes because Python 2's `int()` function doesn't - # work properly on bytearray objects. 
- self._bytes_in_chunk = int(bytes(matches["chunk_size"]), base=16) + self._bytes_in_chunk = int(matches["chunk_size"], base=16) if self._bytes_in_chunk == 0: self._reading_trailer = True return self(buf) @@ -191,7 +183,7 @@ def read_eof(self): ) -class Http10Reader(object): +class Http10Reader: def __call__(self, buf): data = buf.maybe_extract_at_most(999999999) if data is None: diff --git a/h11/_receivebuffer.py b/h11/_receivebuffer.py index c56749a..8b709df 100644 --- a/h11/_receivebuffer.py +++ b/h11/_receivebuffer.py @@ -1,5 +1,3 @@ -import sys - __all__ = ["ReceiveBuffer"] @@ -38,7 +36,7 @@ # slightly clever thing where we delay calling compress() until we've # processed a whole event, which could in theory be slightly more efficient # than the internal bytearray support.) -class ReceiveBuffer(object): +class ReceiveBuffer: def __init__(self): self._data = bytearray() # These are both absolute offsets into self._data: @@ -53,10 +51,6 @@ def __bool__(self): def __bytes__(self): return bytes(self._data[self._start :]) - if sys.version_info[0] < 3: # version specific: Python 2 - __str__ = __bytes__ - __nonzero__ = __bool__ - def __len__(self): return len(self._data) - self._start diff --git a/h11/_state.py b/h11/_state.py index 70a5e04..0f08a09 100644 --- a/h11/_state.py +++ b/h11/_state.py @@ -197,7 +197,7 @@ } -class ConnectionState(object): +class ConnectionState: def __init__(self): # Extra bits of state that don't quite fit into the state model. diff --git a/h11/_util.py b/h11/_util.py index 0a2c28e..eb1a5cd 100644 --- a/h11/_util.py +++ b/h11/_util.py @@ -1,6 +1,3 @@ -import re -import sys - __all__ = [ "ProtocolError", "LocalProtocolError", @@ -74,34 +71,17 @@ def _reraise_as_remote_protocol_error(self): # (exc_info[0]) separately from the exception object (exc_info[1]), # and we only modified the latter. So we really do need to re-raise # the new type explicitly. 
- if sys.version_info[0] >= 3: - # On py3, the traceback is part of the exception object, so our - # in-place modification preserved it and we can just re-raise: - raise self - else: - # On py2, preserving the traceback requires 3-argument - # raise... but on py3 this is a syntax error, so we have to hide - # it inside an exec - exec("raise RemoteProtocolError, self, sys.exc_info()[2]") + # On py3, the traceback is part of the exception object, so our + # in-place modification preserved it and we can just re-raise: + raise self class RemoteProtocolError(ProtocolError): pass -try: - _fullmatch = type(re.compile("")).fullmatch -except AttributeError: - - def _fullmatch(regex, data): # version specific: Python < 3.4 - match = regex.match(data) - if match and match.end() != len(data): - match = None - return match - - def validate(regex, data, msg="malformed data", *format_args): - match = _fullmatch(regex, data) + match = regex.fullmatch(data) if not match: if format_args: msg = msg.format(*format_args) diff --git a/h11/_writers.py b/h11/_writers.py index 7531579..cb5e8a8 100644 --- a/h11/_writers.py +++ b/h11/_writers.py @@ -7,32 +7,12 @@ # - a writer # - or, for body writers, a dict of framin-dependent writer factories -import sys - from ._events import Data, EndOfMessage from ._state import CLIENT, IDLE, SEND_BODY, SEND_RESPONSE, SERVER from ._util import LocalProtocolError __all__ = ["WRITERS"] -# Equivalent of bstr % values, that works on python 3.x for x < 5 -if (3, 0) <= sys.version_info < (3, 5): - - def bytesmod(bstr, values): - decoded_values = [] - for value in values: - if isinstance(value, bytes): - decoded_values.append(value.decode("ascii")) - else: - decoded_values.append(value) - return (bstr.decode("ascii") % tuple(decoded_values)).encode("ascii") - - -else: - - def bytesmod(bstr, values): - return bstr % values - def write_headers(headers, write): # "Since the Host field-value is critical information for handling a @@ -41,17 +21,17 @@ def 
write_headers(headers, write): raw_items = headers._full_items for raw_name, name, value in raw_items: if name == b"host": - write(bytesmod(b"%s: %s\r\n", (raw_name, value))) + write(b"%s: %s\r\n" % (raw_name, value)) for raw_name, name, value in raw_items: if name != b"host": - write(bytesmod(b"%s: %s\r\n", (raw_name, value))) + write(b"%s: %s\r\n" % (raw_name, value)) write(b"\r\n") def write_request(request, write): if request.http_version != b"1.1": raise LocalProtocolError("I only send HTTP/1.1") - write(bytesmod(b"%s %s HTTP/1.1\r\n", (request.method, request.target))) + write(b"%s %s HTTP/1.1\r\n" % (request.method, request.target)) write_headers(request.headers, write) @@ -68,11 +48,11 @@ def write_any_response(response, write): # from stdlib's http.HTTPStatus table. Or maybe just steal their enums # (either by import or copy/paste). We already accept them as status codes # since they're of type IntEnum < int. - write(bytesmod(b"HTTP/1.1 %s %s\r\n", (status_bytes, response.reason))) + write(b"HTTP/1.1 %s %s\r\n" % (status_bytes, response.reason)) write_headers(response.headers, write) -class BodyWriter(object): +class BodyWriter: def __call__(self, event, write): if type(event) is Data: self.send_data(event.data, write) @@ -111,7 +91,7 @@ def send_data(self, data, write): # end-of-message. 
if not data: return - write(bytesmod(b"%x\r\n", (len(data),))) + write(b"%x\r\n" % len(data)) write(data) write(b"\r\n") diff --git a/h11/tests/test_against_stdlib_http.py b/h11/tests/test_against_stdlib_http.py index b4219ff..e6c5db4 100644 --- a/h11/tests/test_against_stdlib_http.py +++ b/h11/tests/test_against_stdlib_http.py @@ -1,26 +1,14 @@ import json import os.path import socket +import socketserver import threading from contextlib import closing, contextmanager +from http.server import SimpleHTTPRequestHandler +from urllib.request import urlopen import h11 -try: - from urllib.request import urlopen -except ImportError: # version specific: Python 2 - from urllib2 import urlopen - -try: - import socketserver -except ImportError: # version specific: Python 2 - import SocketServer as socketserver - -try: - from http.server import SimpleHTTPRequestHandler -except ImportError: # version specific: Python 2 - from SimpleHTTPServer import SimpleHTTPRequestHandler - @contextmanager def socket_server(handler): diff --git a/h11/tests/test_events.py b/h11/tests/test_events.py index 07ffc13..e20f741 100644 --- a/h11/tests/test_events.py +++ b/h11/tests/test_events.py @@ -1,3 +1,5 @@ +from http import HTTPStatus + import pytest from .. 
import _events @@ -154,10 +156,6 @@ def test_events(): def test_intenum_status_code(): # https://github.com/python-hyper/h11/issues/72 - try: - from http import HTTPStatus - except ImportError: - pytest.skip("Only affects Python 3") r = Response(status_code=HTTPStatus.OK, headers=[], http_version="1.0") assert r.status_code == HTTPStatus.OK diff --git a/h11/tests/test_util.py b/h11/tests/test_util.py index 74ab33b..d851bdc 100644 --- a/h11/tests/test_util.py +++ b/h11/tests/test_util.py @@ -93,7 +93,7 @@ def test_bytesify(): assert bytesify("123") == b"123" with pytest.raises(UnicodeEncodeError): - bytesify(u"\u1234") + bytesify("\u1234") with pytest.raises(TypeError): bytesify(10) diff --git a/newsfragments/114.removal.rst b/newsfragments/114.removal.rst new file mode 100644 index 0000000..849b82c --- /dev/null +++ b/newsfragments/114.removal.rst @@ -0,0 +1,2 @@ +Python 2.7 and PyPy 2 support is removed. h11 now requires Python >= 3.5 (including PyPy 3). +Users running ``pip install h11`` on Python 2 will automatically get the last Python 2-compatible version. 
diff --git a/setup.cfg b/setup.cfg index bda6834..0bd1262 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,3 @@ -[bdist_wheel] -universal=1 - [isort] combine_as_imports=True force_grid_wrap=0 diff --git a/setup.py b/setup.py index 25cbbe8..024b9a3 100644 --- a/setup.py +++ b/setup.py @@ -17,15 +17,15 @@ # This means, just install *everything* you see under h11/, even if it # doesn't look like a source file, so long as it appears in MANIFEST.in: include_package_data=True, + python_requires=">=3.5", classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", - "Programming Language :: Python :: 2", - "Programming Language :: Python :: 2.7", "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.5", "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", diff --git a/tox.ini b/tox.ini index de9b566..c2c748e 100644 --- a/tox.ini +++ b/tox.ini @@ -1,13 +1,11 @@ [tox] -envlist = format, py27, py36, py37, py38, pypy, pypy3 +envlist = format, py36, py37, py38, pypy3 [gh-actions] python = - 2.7: py27 3.6: py36 3.7: py37 3.8: py38, format - pypy2: pypy pypy3: pypy3 [testenv] From 868d214e6885a9e316ef3ef3ca0dcf77293f087f Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 18 Nov 2020 15:44:42 +0200 Subject: [PATCH 2/3] Get rid of _EventBundle _EventBundle uses a lot of dynamic python features to save on some duplication, but it slows things down, and will also make it much harder to add static typing. Since these types are now pretty stable, it seems not worth it. 
On the bench/ micro-benchmark: Before: 9322.6 requests/sec After : 10544.6 requests/sec --- h11/_events.py | 270 +++++++++++++++++++++++++++------------ h11/tests/test_events.py | 56 ++------ 2 files changed, 196 insertions(+), 130 deletions(-) diff --git a/h11/_events.py b/h11/_events.py index 1827930..a4a5c5a 100644 --- a/h11/_events.py +++ b/h11/_events.py @@ -24,72 +24,7 @@ request_target_re = re.compile(request_target.encode("ascii")) -class _EventBundle: - _fields = [] - _defaults = {} - - def __init__(self, **kwargs): - _parsed = kwargs.pop("_parsed", False) - allowed = set(self._fields) - for kwarg in kwargs: - if kwarg not in allowed: - raise TypeError( - "unrecognized kwarg {} for {}".format( - kwarg, self.__class__.__name__ - ) - ) - required = allowed.difference(self._defaults) - for field in required: - if field not in kwargs: - raise TypeError( - "missing required kwarg {} for {}".format( - field, self.__class__.__name__ - ) - ) - self.__dict__.update(self._defaults) - self.__dict__.update(kwargs) - - # Special handling for some fields - - if "headers" in self.__dict__: - self.headers = _headers.normalize_and_validate( - self.headers, _parsed=_parsed - ) - - if not _parsed: - for field in ["method", "target", "http_version", "reason"]: - if field in self.__dict__: - self.__dict__[field] = bytesify(self.__dict__[field]) - - if "status_code" in self.__dict__: - if not isinstance(self.status_code, int): - raise LocalProtocolError("status code must be integer") - # Because IntEnum objects are instances of int, but aren't - # duck-compatible (sigh), see gh-72. 
- self.status_code = int(self.status_code) - - self._validate() - - def _validate(self): - pass - - def __repr__(self): - name = self.__class__.__name__ - kwarg_strs = [ - "{}={}".format(field, self.__dict__[field]) for field in self._fields - ] - kwarg_str = ", ".join(kwarg_strs) - return "{}({})".format(name, kwarg_str) - - # Useful for tests - def __eq__(self, other): - return self.__class__ == other.__class__ and self.__dict__ == other.__dict__ - - # This is an unhashable type. - __hash__ = None - - -class Request(_EventBundle): +class Request: """The beginning of an HTTP request. Fields: @@ -123,10 +58,19 @@ class Request(_EventBundle): """ - _fields = ["method", "target", "headers", "http_version"] - _defaults = {"http_version": b"1.1"} + __slots__ = ("method", "target", "headers", "http_version") + + def __init__(self, method, target, headers, http_version=b"1.1", _parsed=False): + self.headers = _headers.normalize_and_validate(headers, _parsed=_parsed) + self.http_version = bytesify(http_version) + + if _parsed: + self.method = method + self.target = target + else: + self.method = bytesify(method) + self.target = bytesify(target) - def _validate(self): # "A server MUST respond with a 400 (Bad Request) status code to any # HTTP/1.1 request message that lacks a Host header field and to any # request message that contains more than one Host header field or a @@ -143,13 +87,31 @@ def _validate(self): validate(request_target_re, self.target, "Illegal target characters") + def __repr__(self): + return "{}(method={}, target={}, headers={}, http_version={})".format( + self.__class__.__name__, + self.method, + self.target, + self.headers, + self.http_version, + ) -class _ResponseBase(_EventBundle): - _fields = ["status_code", "headers", "http_version", "reason"] - _defaults = {"http_version": b"1.1", "reason": b""} + # Useful for tests + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return ( + self.method == other.method 
+ and self.target == other.target + and self.headers == other.headers + and self.http_version == other.http_version + ) + + # This is an unhashable type. + __hash__ = None -class InformationalResponse(_ResponseBase): +class InformationalResponse: """An HTTP informational response. Fields: @@ -179,15 +141,57 @@ class InformationalResponse(_ResponseBase): """ - def _validate(self): + __slots__ = ("status_code", "headers", "http_version", "reason") + + def __init__( + self, status_code, headers, http_version=b"1.1", reason=b"", _parsed=False + ): + self.status_code = status_code + self.headers = _headers.normalize_and_validate(headers, _parsed=_parsed) + + if _parsed: + self.http_version = http_version + self.reason = reason + else: + self.http_version = bytesify(http_version) + self.reason = bytesify(reason) + if not isinstance(self.status_code, int): + raise LocalProtocolError("status code must be integer") + # Because IntEnum objects are instances of int, but aren't + # duck-compatible (sigh), see gh-72. + self.status_code = int(self.status_code) + if not (100 <= self.status_code < 200): raise LocalProtocolError( "InformationalResponse status_code should be in range " "[100, 200), not {}".format(self.status_code) ) + def __repr__(self): + return "{}(status_code={}, headers={}, http_version={}, reason={})".format( + self.__class__.__name__, + self.status_code, + self.headers, + self.http_version, + self.reason, + ) + + # Useful for tests + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return ( + self.status_code == other.status_code + and self.headers == other.headers + and self.http_version == other.http_version + and self.reason == other.reason + ) + + # This is an unhashable type. + __hash__ = None + -class Response(_ResponseBase): +class Response: """The beginning of an HTTP response. 
Fields: @@ -216,7 +220,26 @@ class Response(_ResponseBase): """ - def _validate(self): + __slots__ = ("status_code", "headers", "http_version", "reason") + + def __init__( + self, status_code, headers, http_version=b"1.1", reason=b"", _parsed=False + ): + self.status_code = status_code + self.headers = _headers.normalize_and_validate(headers, _parsed=_parsed) + + if _parsed: + self.http_version = http_version + self.reason = reason + else: + self.http_version = bytesify(http_version) + self.reason = bytesify(reason) + if not isinstance(self.status_code, int): + raise LocalProtocolError("status code must be integer") + # Because IntEnum objects are instances of int, but aren't + # duck-compatible (sigh), see gh-72. + self.status_code = int(self.status_code) + if not (200 <= self.status_code < 600): raise LocalProtocolError( "Response status_code should be in range [200, 600), not {}".format( @@ -224,8 +247,31 @@ def _validate(self): ) ) + def __repr__(self): + return "{}(status_code={}, headers={}, http_version={}, reason={})".format( + self.__class__.__name__, + self.status_code, + self.headers, + self.http_version, + self.reason, + ) + + # Useful for tests + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return ( + self.status_code == other.status_code + and self.headers == other.headers + and self.http_version == other.http_version + and self.reason == other.reason + ) + + # This is an unhashable type. + __hash__ = None + -class Data(_EventBundle): +class Data: """Part of an HTTP message body. 
Fields: @@ -258,8 +304,33 @@ class Data(_EventBundle): """ - _fields = ["data", "chunk_start", "chunk_end"] - _defaults = {"chunk_start": False, "chunk_end": False} + __slots__ = ("data", "chunk_start", "chunk_end") + + def __init__(self, data, chunk_start=False, chunk_end=False): + self.data = data + self.chunk_start = chunk_start + self.chunk_end = chunk_end + + def __repr__(self): + return "{}(data={}, chunk_start={}, chunk_end={})".format( + self.__class__.__name__, + self.data, + self.chunk_start, + self.chunk_end, + ) + + # Useful for tests + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return ( + self.data == other.data + and self.chunk_start == other.chunk_start + and self.chunk_end == other.chunk_end + ) + + # This is an unhashable type. + __hash__ = None # XX FIXME: "A recipient MUST ignore (or consider as an error) any fields that @@ -267,7 +338,7 @@ class Data(_EventBundle): # present in the header section might bypass external security filters." # https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#chunked.trailer.part # Unfortunately, the list of forbidden fields is long and vague :-/ -class EndOfMessage(_EventBundle): +class EndOfMessage: """The end of an HTTP message. Fields: @@ -284,11 +355,28 @@ class EndOfMessage(_EventBundle): """ - _fields = ["headers"] - _defaults = {"headers": []} + __slots__ = ("headers",) + def __init__(self, headers=[], _parsed=False): + self.headers = _headers.normalize_and_validate(headers, _parsed=_parsed) -class ConnectionClosed(_EventBundle): + def __repr__(self): + return "{}(headers={})".format( + self.__class__.__name__, + self.headers, + ) + + # Useful for tests + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return self.headers == other.headers + + # This is an unhashable type. + __hash__ = None + + +class ConnectionClosed: """This event indicates that the sender has closed their outgoing connection. 
@@ -299,4 +387,18 @@ class ConnectionClosed(_EventBundle): No fields. """ - pass + __slots__ = () + + def __repr__(self): + return "{}()".format( + self.__class__.__name__, + ) + + # Useful for tests + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + return True + + # This is an unhashable type. + __hash__ = None diff --git a/h11/tests/test_events.py b/h11/tests/test_events.py index e20f741..1c7c3a3 100644 --- a/h11/tests/test_events.py +++ b/h11/tests/test_events.py @@ -7,52 +7,6 @@ from .._util import LocalProtocolError -def test_event_bundle(): - class T(_events._EventBundle): - _fields = ["a", "b"] - _defaults = {"b": 1} - - def _validate(self): - if self.a == 0: - raise ValueError - - # basic construction and methods - t = T(a=1, b=0) - assert repr(t) == "T(a=1, b=0)" - assert t == T(a=1, b=0) - assert not (t == T(a=2, b=0)) - assert not (t != T(a=1, b=0)) - assert t != T(a=2, b=0) - with pytest.raises(TypeError): - hash(t) - - # check defaults - t = T(a=10) - assert t.a == 10 - assert t.b == 1 - - # no positional args - with pytest.raises(TypeError): - T(1) - - with pytest.raises(TypeError): - T(1, a=1, b=0) - - # unknown field - with pytest.raises(TypeError): - T(a=1, b=0, c=10) - - # missing required field - with pytest.raises(TypeError) as exc: - T(b=0) - # make sure we error on the right missing kwarg - assert "kwarg a" in str(exc.value) - - # _validate is called - with pytest.raises(ValueError): - T(a=0, b=0) - - def test_events(): with pytest.raises(LocalProtocolError): # Missing Host: @@ -66,6 +20,10 @@ def test_events(): assert req.target == b"/" assert req.headers == [(b"a", b"b")] assert req.http_version == b"1.0" + assert repr(req) == ( + "Request(method=b'GET', target=b'/', " + "headers=<Headers([(b'a', b'b')])>, http_version=b'1.0')" + ) # This is also okay -- has a Host (with weird capitalization, which is ok) req = Request( @@ -126,6 +84,10 @@ def test_events(): assert ir.status_code == 100 assert ir.headers == [(b"host", 
b"a")] assert ir.http_version == b"1.1" + assert repr(ir) == ( + "InformationalResponse(status_code=100, headers=<Headers([(b'host', b'a')])>, " + "http_version=b'1.1', reason=b'')" + ) with pytest.raises(LocalProtocolError): InformationalResponse(status_code=200, headers=[("Host", "a")]) @@ -146,9 +108,11 @@ def test_events(): d = Data(data=b"asdf") assert d.data == b"asdf" + assert repr(d) == "Data(data=b'asdf', chunk_start=False, chunk_end=False)" eom = EndOfMessage() assert eom.headers == [] + assert repr(eom) == "EndOfMessage(headers=<Headers([])>)" cc = ConnectionClosed() assert repr(cc) == "ConnectionClosed()" From 40df2b6dbbf3f8a3bb847d3fa61cf20a94f0acba Mon Sep 17 00:00:00 2001 From: Ran Benita Date: Wed, 18 Nov 2020 18:25:54 +0200 Subject: [PATCH 3/3] Inline the validate() function It somewhat obscures the control flow and adds some non-trivial overhead. Inline it in favor of the direct match/if combination. On the bench/ microbenchmark: Before: 10700.2 requests/sec After : 11334.0 requests/sec --- h11/_events.py | 5 +++-- h11/_headers.py | 11 ++++++---- h11/_readers.py | 50 +++++++++++++++++++++++------------------- h11/_util.py | 10 --------- h11/tests/test_util.py | 31 -------------------------- 5 files changed, 38 insertions(+), 69 deletions(-) diff --git a/h11/_events.py b/h11/_events.py index a4a5c5a..73b7716 100644 --- a/h11/_events.py +++ b/h11/_events.py @@ -9,7 +9,7 @@ from . import _headers from ._abnf import request_target -from ._util import bytesify, LocalProtocolError, validate +from ._util import bytesify, LocalProtocolError # Everything in __all__ gets re-exported as part of the h11 public API. 
__all__ = [ @@ -85,7 +85,8 @@ def __init__(self, method, target, headers, http_version=b"1.1", _parsed=False): if host_count > 1: raise LocalProtocolError("Found multiple Host: headers") - validate(request_target_re, self.target, "Illegal target characters") + if request_target_re.fullmatch(self.target) is None: + raise LocalProtocolError("Illegal target characters") def __repr__(self): return "{}(method={}, target={}, headers={}, http_version={})".format( diff --git a/h11/_headers.py b/h11/_headers.py index 7ed39bc..793cd02 100644 --- a/h11/_headers.py +++ b/h11/_headers.py @@ -1,7 +1,7 @@ import re from ._abnf import field_name, field_value -from ._util import bytesify, LocalProtocolError, validate +from ._util import bytesify, LocalProtocolError # Facts # ----- @@ -127,8 +127,10 @@ def normalize_and_validate(headers, _parsed=False): if not _parsed: name = bytesify(name) value = bytesify(value) - validate(_field_name_re, name, "Illegal header name {!r}", name) - validate(_field_value_re, value, "Illegal header value {!r}", value) + if _field_name_re.fullmatch(name) is None: + raise LocalProtocolError("Illegal header name {!r}".format(name)) + if _field_value_re.fullmatch(value) is None: + raise LocalProtocolError("Illegal header value {!r}".format(value)) raw_name = name name = name.lower() if name == b"content-length": @@ -136,7 +138,8 @@ def normalize_and_validate(headers, _parsed=False): if len(lengths) != 1: raise LocalProtocolError("conflicting Content-Length headers") value = lengths.pop() - validate(_content_length_re, value, "bad Content-Length") + if _content_length_re.fullmatch(value) is None: + raise LocalProtocolError("bad Content-Length") if seen_content_length is None: seen_content_length = value new_headers.append((raw_name, name, value)) diff --git a/h11/_readers.py b/h11/_readers.py index 75f00bc..1ac0f29 100644 --- a/h11/_readers.py +++ b/h11/_readers.py @@ -3,9 +3,7 @@ # Strategy: each reader is a callable which takes a ReceiveBuffer object, 
and # either: # 1) consumes some of it and returns an Event -# 2) raises a LocalProtocolError (for consistency -- e.g. we call validate() -# and it might raise a LocalProtocolError, so simpler just to always use -# this) +# 2) raises a LocalProtocolError # 3) returns None, meaning "I need more data" # # If they have a .read_eof attribute, then this will be called if an EOF is @@ -21,7 +19,7 @@ from ._abnf import chunk_header, header_field, request_line, status_line from ._events import * from ._state import * -from ._util import LocalProtocolError, RemoteProtocolError, validate +from ._util import LocalProtocolError, RemoteProtocolError __all__ = ["READERS"] @@ -54,8 +52,10 @@ def _obsolete_line_fold(lines): def _decode_header_lines(lines): for line in _obsolete_line_fold(lines): - matches = validate(header_field_re, line, "illegal header line: {!r}", line) - yield (matches["field_name"], matches["field_value"]) + match = header_field_re.fullmatch(line) + if match is None: + raise LocalProtocolError("illegal header line: {!r}".format(line)) + yield match.group("field_name", "field_value") request_line_re = re.compile(request_line.encode("ascii")) @@ -67,11 +67,15 @@ def maybe_read_from_IDLE_client(buf): return None if not lines: raise LocalProtocolError("no request line received") - matches = validate( - request_line_re, lines[0], "illegal request line: {!r}", lines[0] - ) + match = request_line_re.fullmatch(lines[0]) + if match is None: + raise LocalProtocolError("illegal request line: {!r}".format(lines[0])) return Request( - headers=list(_decode_header_lines(lines[1:])), _parsed=True, **matches + headers=list(_decode_header_lines(lines[1:])), + method=match.group("method"), + target=match.group("target"), + http_version=match.group("http_version"), + _parsed=True, ) @@ -84,14 +88,19 @@ def maybe_read_from_SEND_RESPONSE_server(buf): return None if not lines: raise LocalProtocolError("no response line received") - matches = validate(status_line_re, lines[0], "illegal status 
line: {!r}", lines[0]) + match = status_line_re.fullmatch(lines[0]) + if match is None: + raise LocalProtocolError("illegal status line: {!r}".format(lines[0])) # Tolerate missing reason phrases - if matches["reason"] is None: - matches["reason"] = b"" - status_code = matches["status_code"] = int(matches["status_code"]) + reason = match.group("reason") or b"" + status_code = int(match.group("status_code")) class_ = InformationalResponse if status_code < 200 else Response return class_( - headers=list(_decode_header_lines(lines[1:])), _parsed=True, **matches + status_code=status_code, + headers=list(_decode_header_lines(lines[1:])), + http_version=match.group("http_version"), + reason=reason, + _parsed=True, ) @@ -150,14 +159,11 @@ def __call__(self, buf): chunk_header = buf.maybe_extract_until_next(b"\r\n") if chunk_header is None: return None - matches = validate( - chunk_header_re, - chunk_header, - "illegal chunk header: {!r}", - chunk_header, - ) + match = chunk_header_re.fullmatch(chunk_header) + if match is None: + raise LocalProtocolError("illegal chunk header: {!r}".format(chunk_header)) # XX FIXME: we discard chunk extensions. Does anyone care? 
- self._bytes_in_chunk = int(matches["chunk_size"], base=16) + self._bytes_in_chunk = int(match.group("chunk_size"), base=16) if self._bytes_in_chunk == 0: self._reading_trailer = True return self(buf) diff --git a/h11/_util.py b/h11/_util.py index eb1a5cd..edce8d6 100644 --- a/h11/_util.py +++ b/h11/_util.py @@ -2,7 +2,6 @@ "ProtocolError", "LocalProtocolError", "RemoteProtocolError", - "validate", "make_sentinel", "bytesify", ] @@ -80,15 +79,6 @@ class RemoteProtocolError(ProtocolError): pass -def validate(regex, data, msg="malformed data", *format_args): - match = regex.fullmatch(data) - if not match: - if format_args: - msg = msg.format(*format_args) - raise LocalProtocolError(msg) - return match.groupdict() - - # Sentinel values # # - Inherit identity-based comparison and hashing from object diff --git a/h11/tests/test_util.py b/h11/tests/test_util.py index d851bdc..931ad46 100644 --- a/h11/tests/test_util.py +++ b/h11/tests/test_util.py @@ -42,37 +42,6 @@ def thunk(): assert new_traceback.endswith(orig_traceback) -def test_validate(): - my_re = re.compile(br"(?P<group1>[0-9]+)\.(?P<group2>[0-9]+)") - with pytest.raises(LocalProtocolError): - validate(my_re, b"0.") - - groups = validate(my_re, b"0.1") - assert groups == {"group1": b"0", "group2": b"1"} - - # successful partial matches are an error - must match whole string - with pytest.raises(LocalProtocolError): - validate(my_re, b"0.1xx") - with pytest.raises(LocalProtocolError): - validate(my_re, b"0.1\n") - - -def test_validate_formatting(): - my_re = re.compile(br"foo") - - with pytest.raises(LocalProtocolError) as excinfo: - validate(my_re, b"", "oops") - assert "oops" in str(excinfo.value) - - with pytest.raises(LocalProtocolError) as excinfo: - validate(my_re, b"", "oops {}") - assert "oops {}" in str(excinfo.value) - - with pytest.raises(LocalProtocolError) as excinfo: - validate(my_re, b"", "oops {} xx", 10) - assert "oops 10 xx" in str(excinfo.value) - - def test_make_sentinel(): S = make_sentinel("S") assert 
repr(S) == "S"