diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index 1a68cdc5..7183854b 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -141,6 +141,15 @@ async def app_send(self, message: Optional[ASGISendEvent]) -> None: else: if message["type"] == "http.response.start" and self.state == ASGIHTTPState.REQUEST: self.response = message + headers = build_and_validate_headers(self.response.get("headers", [])) + await self.send( + Response( + stream_id=self.stream_id, + headers=headers, + status_code=int(self.response["status"]), + ) + ) + self.state = ASGIHTTPState.RESPONSE elif ( message["type"] == "http.response.push" and self.scope["http_version"] in PUSH_VERSIONS @@ -175,21 +184,7 @@ async def app_send(self, message: Optional[ASGISendEvent]) -> None: status_code=103, ) ) - elif message["type"] == "http.response.body" and self.state in { - ASGIHTTPState.REQUEST, - ASGIHTTPState.RESPONSE, - }: - if self.state == ASGIHTTPState.REQUEST: - headers = build_and_validate_headers(self.response.get("headers", [])) - await self.send( - Response( - stream_id=self.stream_id, - headers=headers, - status_code=int(self.response["status"]), - ) - ) - self.state = ASGIHTTPState.RESPONSE - + elif message["type"] == "http.response.body" and self.state == ASGIHTTPState.RESPONSE: if ( not suppress_body(self.scope["method"], int(self.response["status"])) and message.get("body", b"") != b"" diff --git a/tests/protocol/test_http_stream.py b/tests/protocol/test_http_stream.py index 5518c8b9..3deb4054 100644 --- a/tests/protocol/test_http_stream.py +++ b/tests/protocol/test_http_stream.py @@ -165,9 +165,7 @@ async def test_send_response(stream: HTTPStream) -> None: await stream.app_send( cast(HTTPResponseStartEvent, {"type": "http.response.start", "status": 200, "headers": []}) ) - assert stream.state == ASGIHTTPState.REQUEST - # Must wait for response before sending anything - stream.send.assert_not_called() # type: ignore + assert stream.state == ASGIHTTPState.RESPONSE await stream.app_send( cast(HTTPResponseBodyEvent, {"type": "http.response.body", "body": b"Body"}) ) @@ -413,15 +411,6 @@ async def test_closure(stream: HTTPStream) -> None: assert stream.app_put.call_args_list == [call({"type": "http.disconnect"})] -@pytest.mark.asyncio -async def test_closed_app_send_noop(stream: HTTPStream) -> None: - stream.closed = True - await stream.app_send( - cast(HTTPResponseStartEvent, {"type": "http.response.start", "status": 200, "headers": []}) - ) - stream.send.assert_not_called() # type: ignore - - @pytest.mark.asyncio async def test_abnormal_close_logging() -> None: config = Config()