From 62b77a1eacf031820093607bb2efa3c3e9740f63 Mon Sep 17 00:00:00 2001 From: Sachin Sagadevan Date: Sun, 3 Jan 2021 11:45:51 -0800 Subject: [PATCH] Fixing Issue #319 in hyper-h2, adding ability to enable/disable RFC8441 extension through H2Configuration. --- src/h2/config.py | 9 ++ src/h2/connection.py | 2 + src/h2/stream.py | 1 + src/h2/utilities.py | 24 +++- test/test_config.py | 11 ++ test/test_invalid_headers.py | 208 ++++++++++++++++++++++++++++++++++- test/test_rfc8441.py | 6 +- 7 files changed, 253 insertions(+), 8 deletions(-) diff --git a/src/h2/config.py b/src/h2/config.py index 730b61124..bb62b8ea0 100644 --- a/src/h2/config.py +++ b/src/h2/config.py @@ -129,6 +129,9 @@ class H2Configuration: normalize_inbound_headers = _BooleanConfigOption( 'normalize_inbound_headers' ) + enable_rfc8441 = _BooleanConfigOption( + 'enable_rfc8441' + ) def __init__(self, client_side=True, @@ -137,6 +140,7 @@ def __init__(self, normalize_outbound_headers=True, validate_inbound_headers=True, normalize_inbound_headers=True, + enable_rfc8441=False, logger=None): self.client_side = client_side self.header_encoding = header_encoding @@ -144,6 +148,7 @@ def __init__(self, self.normalize_outbound_headers = normalize_outbound_headers self.validate_inbound_headers = validate_inbound_headers self.normalize_inbound_headers = normalize_inbound_headers + self.enable_rfc8441 = enable_rfc8441 self.logger = logger or DummyLogger(__name__) @property @@ -168,3 +173,7 @@ def header_encoding(self, value): if value is True: raise ValueError("header_encoding cannot be True") self._header_encoding = value + + @property + def is_rfc8441_enabled(self): + return self.enable_rfc8441 \ No newline at end of file diff --git a/src/h2/connection.py b/src/h2/connection.py index aa3071144..20a9b6940 100644 --- a/src/h2/connection.py +++ b/src/h2/connection.py @@ -326,6 +326,8 @@ def __init__(self, config=None): self.DEFAULT_MAX_HEADER_LIST_SIZE, } ) + if self.config.is_rfc8441_enabled: + self.local_settings.enable_connect_protocol = 1 self.remote_settings = Settings(client=not self.config.client_side) # The current value of the connection flow control windows on the diff --git a/src/h2/stream.py b/src/h2/stream.py index 3c29b2431..3cbb3e0fe 100644 --- a/src/h2/stream.py +++ b/src/h2/stream.py @@ -1230,6 +1230,7 @@ def _build_hdr_validation_flags(self, events): is_trailer=is_trailer, is_response_header=is_response_header, is_push_promise=is_push_promise, + is_rfc8441_enabled=self.config.is_rfc8441_enabled, ) def _build_headers_frames(self, diff --git a/src/h2/utilities.py b/src/h2/utilities.py index eb07f575e..afe3cff1d 100644 --- a/src/h2/utilities.py +++ b/src/h2/utilities.py @@ -186,7 +186,7 @@ def authority_from_headers(headers): # should be applied to a given set of headers. HeaderValidationFlags = collections.namedtuple( 'HeaderValidationFlags', - ['is_client', 'is_trailer', 'is_response_header', 'is_push_promise'] + ['is_client', 'is_trailer', 'is_response_header', 'is_push_promise', 'is_rfc8441_enabled'] ) @@ -316,6 +316,18 @@ def _assert_header_in_set(string_header, bytes_header, header_set): ) +def _assert_header_not_in_set(string_header, bytes_header, header_set): + """ + Given a set of header names, checks whether the string or byte version of + the header name is not present. Raises a Protocol error with the appropriate + error if it's present. + """ + if (string_header in header_set or bytes_header in header_set): + raise ProtocolError( + "Header block must not contain %s header" % string_header + ) + + def _reject_pseudo_header_fields(headers, hdr_validation_flags): """ Raises a ProtocolError if duplicate pseudo-header fields are found in a @@ -396,9 +408,15 @@ def _check_pseudo_header_field_acceptability(pseudo_headers, not hdr_validation_flags.is_trailer): # This is a request, so we need to have seen :path, :method, and # :scheme. - _assert_header_in_set(u':path', b':path', pseudo_headers) _assert_header_in_set(u':method', b':method', pseudo_headers) - _assert_header_in_set(u':scheme', b':scheme', pseudo_headers) + if method == b'CONNECT': + _assert_header_in_set(u':authority', b':authority', pseudo_headers) + if method == b'CONNECT' and not hdr_validation_flags.is_rfc8441_enabled: + _assert_header_not_in_set(u':path', b':path', pseudo_headers) + _assert_header_not_in_set(u':scheme', b':scheme', pseudo_headers) + else: + _assert_header_in_set(u':path', b':path', pseudo_headers) + _assert_header_in_set(u':scheme', b':scheme', pseudo_headers) invalid_request_headers = pseudo_headers & _RESPONSE_ONLY_HEADERS if invalid_request_headers: raise ProtocolError( diff --git a/test/test_config.py b/test/test_config.py index 8eb7fdc86..3e6cb8be1 100644 --- a/test/test_config.py +++ b/test/test_config.py @@ -22,6 +22,7 @@ def test_defaults(self): config = h2.config.H2Configuration() assert config.client_side assert config.header_encoding is None + assert config.is_rfc8441_enabled is False assert isinstance(config.logger, h2.config.DummyLogger) boolean_config_options = [ @@ -30,6 +31,7 @@ def test_defaults(self): 'normalize_outbound_headers', 'validate_inbound_headers', 'normalize_inbound_headers', + 'enable_rfc8441', ] @pytest.mark.parametrize('option_name', boolean_config_options) @@ -120,6 +122,15 @@ def test_header_encoding_is_reflected_attr(self, header_encoding): config.header_encoding = header_encoding assert config.header_encoding == header_encoding + @pytest.mark.parametrize('enable_rfc8441', [False, True]) + def test_header_encoding_is_reflected_init(self, enable_rfc8441): + """ + The value of ``enable_rfc8441``, when set, is reflected in the value + via the initializer. + """ + config = h2.config.H2Configuration(enable_rfc8441=enable_rfc8441) + assert config.is_rfc8441_enabled == enable_rfc8441 + def test_logger_instance_is_reflected(self): """ The value of ``logger``, when set, is reflected in the value. diff --git a/test/test_invalid_headers.py b/test/test_invalid_headers.py index a37995073..eda8597fe 100644 --- a/test/test_invalid_headers.py +++ b/test/test_invalid_headers.py @@ -423,10 +423,10 @@ class TestFilter(object): hdr_validation_combos = [ h2.utilities.HeaderValidationFlags( - is_client, is_trailer, is_response_header, is_push_promise + is_client, is_trailer, is_response_header, is_push_promise, is_rfc8441_enabled ) - for is_client, is_trailer, is_response_header, is_push_promise in ( - itertools.product([True, False], repeat=4) + for is_client, is_trailer, is_response_header, is_push_promise, is_rfc8441_enabled in ( + itertools.product([True, False], repeat=5) ) ] @@ -494,6 +494,68 @@ class TestFilter(object): (u':path', u''), ), ) + invalid_connect_request_block_bytes = ( + # First, missing :authority with :protocol header + ( + (b':method', b'CONNECT'), + (b':protocol', b'test_value'), + (b'host', b'example.com'), + ), + # Next, missing :authority without :protocol header + ( + (b':method', b'CONNECT'), + (b'host', b'example.com'), + ) + ) + invalid_connect_request_block_unicode = ( + # First, missing :authority with :protocol header + ( + (u':method', u'CONNECT'), + (u':protocol', u'websocket'), + (u'host', u'example.com'), + ), + # Next, missing :authority without :protocol header + ( + (u':method', u'CONNECT'), + (u'host', u'example.com'), + ), + ) + invalid_connect_req_rfc8441_bytes = ( + # First, missing :path header + ( + (b':authority', b'example.com'), + (b':method', b'CONNECT'), + (b':protocol', b'test_value'), + (b':scheme', b'https'), + (b'host', b'example.com'), + ), + # Next, missing :scheme header + ( + (b':authority', b'example.com'), + (b':method', b'CONNECT'), + (b':protocol', b'test_value'), + (b':path', b'/'), + (b'host', b'example.com'), + ) + ) + invalid_connect_req_rfc8441_unicode = ( + # First, missing :path header + ( + (u':authority', u'example.com'), + (u':method', u'CONNECT'), + (u':protocol', u'test_value'), + (u':scheme', u'https'), + (u'host', u'example.com'), + ), + # Next, missing :scheme header + ( + (u':authority', u'example.com'), + (u':method', u'CONNECT'), + (u':protocol', u'test_value'), + (u':path', u'/'), + (u'host', u'example.com'), + ) + ) # All headers that are forbidden from either request or response blocks. forbidden_request_headers_bytes = (b':status',) @@ -504,6 +566,8 @@ class TestFilter(object): forbidden_response_headers_unicode = ( u':path', u':scheme', u':authority', u':method' ) + forbidden_connect_request_headers_bytes = (b':scheme', b':path') + forbidden_connect_request_headers_unicode = (u':scheme', u':path') @pytest.mark.parametrize('validation_function', validation_functions) @pytest.mark.parametrize('hdr_validation_flags', hdr_validation_combos) @@ -688,6 +752,144 @@ def test_inbound_resp_header_extra_pseudo_headers(self, with pytest.raises(h2.exceptions.ProtocolError): list(h2.utilities.validate_headers(headers, hdr_validation_flags)) + @pytest.mark.parametrize( + 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + ) + @pytest.mark.parametrize( + 'header_block', ( + invalid_connect_request_block_bytes + + invalid_connect_request_block_unicode + ) + ) + def test_outbound_connect_req_missing_pseudo_headers(self, + hdr_validation_flags, + header_block): + if not hdr_validation_flags.is_rfc8441_enabled: + with pytest.raises(h2.exceptions.ProtocolError) as protocol_error: + list( + h2.utilities.validate_outbound_headers( + header_block, hdr_validation_flags + ) + ) + # Check if missing :path and :scheme headers doesn't throw ProtocolError exception + assert "missing mandatory :path header" not in str(protocol_error.value) + assert "missing mandatory :scheme header" not in str(protocol_error.value) + + @pytest.mark.parametrize( + 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + ) + @pytest.mark.parametrize( + 'header_block', invalid_connect_request_block_bytes + ) + def test_inbound_connect_req_missing_pseudo_headers(self, + hdr_validation_flags, + header_block): + if not hdr_validation_flags.is_rfc8441_enabled: + with pytest.raises(h2.exceptions.ProtocolError) as protocol_error: + list( + h2.utilities.validate_headers( + header_block, hdr_validation_flags + ) + ) + # Check if missing :path and :scheme headers doesn't throw ProtocolError exception + assert "missing mandatory :path header" not in str(protocol_error.value) + assert "missing mandatory :scheme header" not in str(protocol_error.value) + + @pytest.mark.parametrize( + 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + ) + @pytest.mark.parametrize( + 'invalid_header', + forbidden_connect_request_headers_bytes + forbidden_connect_request_headers_unicode + ) + def test_outbound_connect_req_extra_pseudo_headers(self, + hdr_validation_flags, + invalid_header): + """ + Inbound request header blocks containing the forbidden request headers + fail validation. + """ + headers = [ + (b':authority', b'google.com'), + (b':method', b'CONNECT'), + (b':protocol', b'websocket'), + ] + if not hdr_validation_flags.is_rfc8441_enabled: + headers.append((invalid_header, b'some value')) + with pytest.raises(h2.exceptions.ProtocolError) as protocol_error: + list(h2.utilities.validate_outbound_headers(headers, hdr_validation_flags)) + if isinstance(invalid_header, bytes): + expected_exception_string = (b'Header block must not contain ' + invalid_header + b' header')\ + .decode("utf-8") + else: + expected_exception_string = 'Header block must not contain ' + invalid_header + ' header' + assert expected_exception_string == str(protocol_error.value) + + @pytest.mark.parametrize( + 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + ) + @pytest.mark.parametrize( + 'invalid_header', + forbidden_connect_request_headers_bytes + ) + def test_inbound_connect_req_extra_pseudo_headers(self, + hdr_validation_flags, + invalid_header): + """ + Inbound request header blocks containing the forbidden request headers + fail validation. + """ + headers = [ + (b':authority', b'google.com'), + (b':method', b'CONNECT'), + (b':protocol', b'some value'), + ] + if not hdr_validation_flags.is_rfc8441_enabled: + headers.append((invalid_header, b'some value')) + with pytest.raises(h2.exceptions.ProtocolError) as protocol_error: + list(h2.utilities.validate_headers(headers, hdr_validation_flags)) + assert (b'Header block must not contain ' + invalid_header + b' header').decode("utf-8") \ + == str(protocol_error.value) + + + @pytest.mark.parametrize( + 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + ) + @pytest.mark.parametrize( + 'header_block', ( + invalid_connect_req_rfc8441_bytes + + invalid_connect_req_rfc8441_unicode + ) + ) + def test_outbound_connect_req_rfc8441_missing_pseudo_headers(self, + hdr_validation_flags, + header_block): + if hdr_validation_flags.is_rfc8441_enabled: + with pytest.raises(h2.exceptions.ProtocolError): + list( + h2.utilities.validate_outbound_headers( + header_block, hdr_validation_flags + ) + ) + + @pytest.mark.parametrize( + 'hdr_validation_flags', hdr_validation_request_headers_no_trailer + ) + @pytest.mark.parametrize( + 'header_block', invalid_connect_req_rfc8441_bytes + ) + def test_inbound_connect_req_rfc8441_missing_pseudo_headers(self, + hdr_validation_flags, + header_block): + if hdr_validation_flags.is_rfc8441_enabled: + print("here", header_block) + with pytest.raises(h2.exceptions.ProtocolError): + list( + h2.utilities.validate_headers( + header_block, hdr_validation_flags + ) + ) + class TestOversizedHeaders(object): """ diff --git a/test/test_rfc8441.py b/test/test_rfc8441.py index b2bf881fd..d65b3da72 100644 --- a/test/test_rfc8441.py +++ b/test/test_rfc8441.py @@ -26,12 +26,14 @@ def test_can_send_headers(self, frame_factory): (b'user-agent', b'someua/0.0.1'), ] - client = h2.connection.H2Connection() + client = h2.connection.H2Connection( + config=h2.config.H2Configuration(enable_rfc8441=True) + ) client.initiate_connection() client.send_headers(stream_id=1, headers=headers) server = h2.connection.H2Connection( - config=h2.config.H2Configuration(client_side=False) + config=h2.config.H2Configuration(client_side=False, enable_rfc8441=True) ) events = server.receive_data(client.data_to_send()) event = events[1]