Skip to content
This repository has been archived by the owner on Jan 13, 2021. It is now read-only.

Add ENABLE_PUSH flag in the Upgrade HTTP2-Settings header #310

Open
wants to merge 11 commits into
base: development
Choose a base branch
from
19 changes: 18 additions & 1 deletion hyper/common/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ def __init__(self,
self._port = port
self._h1_kwargs = {
'secure': secure, 'ssl_context': ssl_context,
'proxy_host': proxy_host, 'proxy_port': proxy_port
'proxy_host': proxy_host, 'proxy_port': proxy_port,
'enable_push': enable_push
}
self._h2_kwargs = {
'window_manager': window_manager, 'enable_push': enable_push,
Expand Down Expand Up @@ -143,6 +144,22 @@ def get_response(self, *args, **kwargs):

return self._conn.get_response(1)

def get_pushes(self, *args, **kwargs):
try:
return self._conn.get_pushes(*args, **kwargs)
except HTTPUpgrade as e:
assert e.negotiated == H2C_PROTOCOL

self._conn = HTTP20Connection(
self._host, self._port, **self._h2_kwargs
)

self._conn._connect_upgrade(e.sock, True)
# stream id 1 is used by the upgrade request and response
# and is half-closed by the client

return self._conn.get_pushes(*args, **kwargs)

# The following two methods are the implementation of the context manager
# protocol.
def __enter__(self): # pragma: no cover
Expand Down
63 changes: 42 additions & 21 deletions hyper/http11/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class HTTP11Connection(object):
"""

version = HTTPVersion.http11
_response = None

def __init__(self, host, port=None, secure=None, ssl_context=None,
proxy_host=None, proxy_port=None, **kwargs):
Expand All @@ -78,6 +79,7 @@ def __init__(self, host, port=None, secure=None, ssl_context=None,

# only send http upgrade headers for non-secure connection
self._send_http_upgrade = not self.secure
self._enable_push = kwargs.get('enable_push')

self.ssl_context = ssl_context
self._sock = None
Expand All @@ -104,6 +106,12 @@ def __init__(self, host, port=None, secure=None, ssl_context=None,
#: the standard hyper parsing interface.
self.parser = Parser()

def get_pushes(self, stream_id=None, capture_all=False):
"""
Dummy method to trigger h2c upgrade.
"""
self._get_response()

def connect(self):
"""
Connect to the server specified when the object was created. This is a
Expand Down Expand Up @@ -188,6 +196,7 @@ def request(self, method, url, body=None, headers=None):
# Next, send the request body.
if body:
self._send_body(body, body_type)
self._response = None

return

Expand All @@ -198,31 +207,39 @@ def get_response(self):
This is an early beta, so the response object is pretty stupid. That's
ok, we'll fix it later.
"""
headers = HTTPHeaderMap()
resp = self._get_response()
self._response = None
return resp

response = None
while response is None:
# 'encourage' the socket to receive data.
self._sock.fill()
response = self.parser.parse_response(self._sock.buffer)
def _get_response(self):
if self._response is None:

for n, v in response.headers:
headers[n.tobytes()] = v.tobytes()
headers = HTTPHeaderMap()

self._sock.advance_buffer(response.consumed)
response = None
while response is None:
# 'encourage' the socket to receive data.
self._sock.fill()
response = self.parser.parse_response(self._sock.buffer)

if (response.status == 101 and
for n, v in response.headers:
headers[n.tobytes()] = v.tobytes()

self._sock.advance_buffer(response.consumed)

if (response.status == 101 and
b'upgrade' in headers['connection'] and
H2C_PROTOCOL.encode('utf-8') in headers['upgrade']):
raise HTTPUpgrade(H2C_PROTOCOL, self._sock)

return HTTP11Response(
response.status,
response.msg.tobytes(),
headers,
self._sock,
self
)
H2C_PROTOCOL.encode('utf-8') in headers['upgrade']):
raise HTTPUpgrade(H2C_PROTOCOL, self._sock)

self._response = HTTP11Response(
response.status,
response.msg.tobytes(),
headers,
self._sock,
self
)
return self._response

def _send_headers(self, method, url, headers):
"""
Expand Down Expand Up @@ -276,6 +293,10 @@ def _add_upgrade_headers(self, headers):
# Settings header.
http2_settings = SettingsFrame(0)
http2_settings.settings[SettingsFrame.INITIAL_WINDOW_SIZE] = 65535
if self._enable_push is not None:
http2_settings.settings[SettingsFrame.ENABLE_PUSH] = (
int(self._enable_push)
)
encoded_settings = base64.urlsafe_b64encode(
http2_settings.serialize_body()
)
Expand Down Expand Up @@ -348,7 +369,7 @@ def _send_file_like_obj(self, fobj):
Handles streaming a file-like object to the network.
"""
while True:
block = fobj.read(16*1024)
block = fobj.read(16 * 1024)
if not block:
break

Expand Down
11 changes: 9 additions & 2 deletions hyper/http20/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def __init__(self, host, port=None, secure=None, window_manager=None,
else:
self.secure = False

self._delay_recv = False
self._enable_push = enable_push
self.ssl_context = ssl_context

Expand Down Expand Up @@ -313,6 +314,9 @@ def get_response(self, stream_id=None):
get a response.
:returns: A :class:`HTTP20Response <hyper.HTTP20Response>` object.
"""
if self._delay_recv:
self._recv_cb()
self._delay_recv = False
stream = self._get_stream(stream_id)
return HTTP20Response(stream.getheaders(), stream)

Expand Down Expand Up @@ -384,7 +388,7 @@ def connect(self):

self._send_preamble()

def _connect_upgrade(self, sock):
def _connect_upgrade(self, sock, no_recv=False):
"""
Called by the generic HTTP connection when we're being upgraded. Locks
in a new socket and places the backing state machine into an upgrade
Expand All @@ -405,7 +409,10 @@ def _connect_upgrade(self, sock):
s = self._new_stream(local_closed=True)
self.recent_stream = s

self._recv_cb()
if no_recv: # To delay I/O operation
self._delay_recv = True
else:
self._recv_cb()

def _send_preamble(self):
"""
Expand Down
1 change: 1 addition & 0 deletions test/test_abstraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_h1_kwargs(self):
'proxy_host': False,
'proxy_port': False,
'other_kwarg': True,
'enable_push': True,
}

def test_h2_kwargs(self):
Expand Down
33 changes: 28 additions & 5 deletions test/test_hyper.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
PingFrame, FRAME_MAX_ALLOWED_LEN
)
from hpack.hpack_compat import Encoder
from hyper import HTTPConnection
from hyper.http20.connection import HTTP20Connection
from hyper.http20.response import HTTP20Response, HTTP20Push
from hyper.http20.exceptions import ConnectionError, StreamResetError
Expand Down Expand Up @@ -731,8 +732,8 @@ def add_data_frame(self, stream_id, data, end_stream=False):
frame.flags.add('END_STREAM')
self.frames.append(frame)

def request(self):
self.conn = HTTP20Connection('www.google.com', enable_push=True)
def request(self, enable_push=True):
self.conn = HTTP20Connection('www.google.com', enable_push=enable_push)
self.conn._sock = DummySocket()
self.conn._sock.buffer = BytesIO(
b''.join([frame.serialize() for frame in self.frames])
Expand Down Expand Up @@ -934,13 +935,13 @@ def test_reset_pushed_streams_when_push_disabled(self):
1, [(':status', '200'), ('content-type', 'text/html')]
)

self.request()
self.conn._enable_push = False
self.request(enable_push=False)
self.conn.get_response()

f = RstStreamFrame(2)
f.error_code = 7
assert self.conn._sock.queue[-1] == f.serialize()
print(self.conn._sock.queue)
assert self.conn._sock.queue[-1].endswith(f.serialize())

def test_pushed_requests_ignore_unexpected_headers(self):
headers = HTTPHeaderMap([
Expand All @@ -956,7 +957,29 @@ def test_pushed_requests_ignore_unexpected_headers(self):
assert p.request_headers == HTTPHeaderMap([('no', 'no')])


class TestUpgradingPush(TestServerPush):
http101 = (b"HTTP/1.1 101 Switching Protocols\r\n"
b"Connection: upgrade\r\n"
b"Upgrade: h2c\r\n"
b"\r\n")

def setup_method(self, method):
self.frames = [SettingsFrame(0)] # Server-side preface
self.encoder = Encoder()
self.conn = None

def request(self, enable_push=True):
self.conn = HTTPConnection('www.google.com', enable_push=enable_push)
self.conn._conn._sock = DummySocket()
self.conn._conn._sock.buffer = BytesIO(
self.http101 + b''.join([frame.serialize()
for frame in self.frames])
)
self.conn.request('GET', '/')


class TestResponse(object):

def test_status_is_stripped_from_headers(self):
headers = HTTPHeaderMap([(':status', '200')])
resp = HTTP20Response(headers, None)
Expand Down
Loading