diff --git a/tornado/netutil.py b/tornado/netutil.py index e83afad570..b0cf21c5a3 100644 --- a/tornado/netutil.py +++ b/tornado/netutil.py @@ -209,17 +209,20 @@ def bind_unix_socket( # Hurd doesn't support SO_REUSEADDR raise sock.setblocking(False) - try: - st = os.stat(file) - except FileNotFoundError: - pass - else: - if stat.S_ISSOCK(st.st_mode): - os.remove(file) + if not file.startswith("\0"): + try: + st = os.stat(file) + except FileNotFoundError: + pass else: - raise ValueError("File %s exists and is not a socket", file) - sock.bind(file) - os.chmod(file, mode) + if stat.S_ISSOCK(st.st_mode): + os.remove(file) + else: + raise ValueError("File %s exists and is not a socket", file) + sock.bind(file) + os.chmod(file, mode) + else: + sock.bind(file) sock.listen(backlog) return sock diff --git a/tornado/test/httpserver_test.py b/tornado/test/httpserver_test.py index 762c23f6e0..2d16cf2f0c 100644 --- a/tornado/test/httpserver_test.py +++ b/tornado/test/httpserver_test.py @@ -834,15 +834,14 @@ def setUp(self): super().setUp() self.tmpdir = tempfile.mkdtemp() self.sockfile = os.path.join(self.tmpdir, "test.sock") - sock = netutil.bind_unix_socket(self.sockfile) app = Application([("/hello", HelloWorldRequestHandler)]) self.server = HTTPServer(app) - self.server.add_socket(sock) - self.stream = IOStream(socket.socket(socket.AF_UNIX)) - self.io_loop.run_sync(lambda: self.stream.connect(self.sockfile)) + if sys.platform.startswith('linux'): + self.sockabstract = "\0" + os.path.basename(self.tmpdir) + self.server.add_socket(netutil.bind_unix_socket(self.sockabstract)) + self.server.add_socket(netutil.bind_unix_socket(self.sockfile)) def tearDown(self): - self.stream.close() self.io_loop.run_sync(self.server.close_all_connections) self.server.stop() shutil.rmtree(self.tmpdir) @@ -850,21 +849,38 @@ def tearDown(self): @gen_test def test_unix_socket(self): - self.stream.write(b"GET /hello HTTP/1.0\r\n\r\n") - response = yield self.stream.read_until(b"\r\n") - self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") - header_data = yield self.stream.read_until(b"\r\n\r\n") - headers = HTTPHeaders.parse(header_data.decode("latin1")) - body = yield self.stream.read_bytes(int(headers["Content-Length"])) - self.assertEqual(body, b"Hello world") + with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream: + stream.connect(self.sockfile) + stream.write(b"GET /hello HTTP/1.0\r\n\r\n") + response = yield stream.read_until(b"\r\n") + self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") + header_data = yield stream.read_until(b"\r\n\r\n") + headers = HTTPHeaders.parse(header_data.decode("latin1")) + body = yield stream.read_bytes(int(headers["Content-Length"])) + self.assertEqual(body, b"Hello world") + + @unittest.skipUnless(sys.platform.startswith('linux'), 'requires Linux') + @gen_test + def test_unix_socket_abstract(self): + with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream: + stream.connect(self.sockabstract) + stream.write(b"GET /hello HTTP/1.0\r\n\r\n") + response = yield stream.read_until(b"\r\n") + self.assertEqual(response, b"HTTP/1.1 200 OK\r\n") + header_data = yield stream.read_until(b"\r\n\r\n") + headers = HTTPHeaders.parse(header_data.decode("latin1")) + body = yield stream.read_bytes(int(headers["Content-Length"])) + self.assertEqual(body, b"Hello world") @gen_test def test_unix_socket_bad_request(self): # Unix sockets don't have remote addresses so they just return an # empty string. with ExpectLog(gen_log, "Malformed HTTP message from", level=logging.INFO): - self.stream.write(b"garbage\r\n\r\n") - response = yield self.stream.read_until_close() + with closing(IOStream(socket.socket(socket.AF_UNIX))) as stream: + stream.connect(self.sockfile) + stream.write(b"garbage\r\n\r\n") + response = yield stream.read_until_close() self.assertEqual(response, b"HTTP/1.1 400 Bad Request\r\n\r\n")