diff --git a/bin/pip_constraint_helpers.py b/bin/pip_constraint_helpers.py index 299f60b7db..10b202bb65 100644 --- a/bin/pip_constraint_helpers.py +++ b/bin/pip_constraint_helpers.py @@ -1,4 +1,5 @@ """A set of functions helping generating pip constraint files.""" + import functools import os import platform @@ -6,11 +7,11 @@ import sys PYTHON_IMPLEMENTATION_MAP = { # noqa: WPS407 - 'cpython': 'cp', - 'ironpython': 'ip', - 'jython': 'jy', - 'python': 'py', - 'pypy': 'pp', + "cpython": "cp", + "ironpython": "ip", + "jython": "jy", + "python": "py", + "pypy": "pp", } PYTHON_IMPLEMENTATION = platform.python_implementation() @@ -33,11 +34,10 @@ def get_runtime_python_tag(): python_tag_prefix = PYTHON_IMPLEMENTATION_MAP.get(sys_impl, sys_impl) # pylint: disable=possibly-unused-variable - python_minor_ver_tag = ''.join(map(str, python_minor_ver)) + python_minor_ver_tag = "".join(map(str, python_minor_ver)) return ( - '{python_tag_prefix!s}{python_minor_ver_tag!s}'. - format(**locals()) # noqa: WPS421 + "{python_tag_prefix!s}{python_minor_ver_tag!s}".format(**locals()) # noqa: WPS421 ) @@ -54,21 +54,20 @@ def get_constraint_file_path(req_dir, toxenv, python_tag): # pylint: disable=possibly-unused-variable platform_machine = platform.machine().lower() - if toxenv in {'py', 'python'}: - extra_prefix = 'py' if PYTHON_IMPLEMENTATION == 'PyPy' else '' - toxenv = '{prefix}py{ver}'.format( + if toxenv in {"py", "python"}: + extra_prefix = "py" if PYTHON_IMPLEMENTATION == "PyPy" else "" + toxenv = "{prefix}py{ver}".format( prefix=extra_prefix, ver=python_tag[2:], ) - if sys_platform == 'linux2': - sys_platform = 'linux' + if sys_platform == "linux2": + sys_platform = "linux" constraint_name = ( - 'tox-{toxenv}-{python_tag}-{sys_platform}-{platform_machine}'. - format(**locals()) # noqa: WPS421 + "tox-{toxenv}-{python_tag}-{sys_platform}-{platform_machine}".format(**locals()) # noqa: WPS421 ) - return os.path.join(req_dir, os.path.extsep.join((constraint_name, 'txt'))) + return os.path.join(req_dir, os.path.extsep.join((constraint_name, "txt"))) def make_pip_cmd(pip_args, constraint_file_path): @@ -79,14 +78,15 @@ def make_pip_cmd(pip_args, constraint_file_path): :returns: pip command. """ - pip_cmd = [sys.executable, '-m', 'pip'] + pip_args + pip_cmd = [sys.executable, "-m", "pip"] + pip_args if os.path.isfile(constraint_file_path): - pip_cmd += ['--constraint', constraint_file_path] + pip_cmd += ["--constraint", constraint_file_path] else: print_info( - 'WARNING: The expected pinned constraints file for the current ' - 'env does not exist (should be "{constraint_file_path}").'. - format(**locals()), # noqa: WPS421 + "WARNING: The expected pinned constraints file for the current " + 'env does not exist (should be "{constraint_file_path}").'.format( + **locals() + ), # noqa: WPS421 ) return pip_cmd @@ -97,7 +97,6 @@ def run_cmd(cmd): :param cmd: The command to invoke. """ print_info( - 'Invoking the following command: {cmd}'. - format(cmd=' '.join(cmd)), + "Invoking the following command: {cmd}".format(cmd=" ".join(cmd)), ) subprocess.check_call(cmd) # noqa: S603 diff --git a/cheroot/__init__.py b/cheroot/__init__.py index 4ae1d9ae11..55c2651453 100644 --- a/cheroot/__init__.py +++ b/cheroot/__init__.py @@ -4,6 +4,6 @@ try: - __version__ = metadata.version('cheroot') + __version__ = metadata.version("cheroot") except Exception: - __version__ = 'unknown' + __version__ = "unknown" diff --git a/cheroot/__main__.py b/cheroot/__main__.py index d2e27c1083..b56368ce14 100644 --- a/cheroot/__main__.py +++ b/cheroot/__main__.py @@ -2,5 +2,5 @@ from .cli import main -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/cheroot/_compat.py b/cheroot/_compat.py index dbe5c6d2ff..8bcfd7dfa9 100644 --- a/cheroot/_compat.py +++ b/cheroot/_compat.py @@ -7,44 +7,45 @@ try: import ssl + IS_ABOVE_OPENSSL10 = ssl.OPENSSL_VERSION_INFO >= (1, 1) del ssl except ImportError: IS_ABOVE_OPENSSL10 = None -IS_CI = bool(os.getenv('CI')) -IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW')) +IS_CI = bool(os.getenv("CI")) +IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv("GITHUB_WORKFLOW")) -IS_PYPY = platform.python_implementation() == 'PyPy' +IS_PYPY = platform.python_implementation() == "PyPy" SYS_PLATFORM = platform.system() -IS_WINDOWS = SYS_PLATFORM == 'Windows' -IS_LINUX = SYS_PLATFORM == 'Linux' -IS_MACOS = SYS_PLATFORM == 'Darwin' -IS_SOLARIS = SYS_PLATFORM == 'SunOS' +IS_WINDOWS = SYS_PLATFORM == "Windows" +IS_LINUX = SYS_PLATFORM == "Linux" +IS_MACOS = SYS_PLATFORM == "Darwin" +IS_SOLARIS = SYS_PLATFORM == "SunOS" PLATFORM_ARCH = platform.machine() -IS_PPC = PLATFORM_ARCH.startswith('ppc') +IS_PPC = PLATFORM_ARCH.startswith("ppc") -def ntob(n, encoding='ISO-8859-1'): +def ntob(n, encoding="ISO-8859-1"): """Return the native string as bytes in the given encoding.""" assert_native(n) # In Python 3, the native string type is unicode return n.encode(encoding) -def ntou(n, encoding='ISO-8859-1'): +def ntou(n, encoding="ISO-8859-1"): """Return the native string as Unicode with the given encoding.""" assert_native(n) # In Python 3, the native string type is unicode return n -def bton(b, encoding='ISO-8859-1'): +def bton(b, encoding="ISO-8859-1"): """Return the byte string as native string in the given encoding.""" return b.decode(encoding) @@ -57,7 +58,7 @@ def assert_native(n): """ if not isinstance(n, str): - raise TypeError('n must be a native str (got %s)' % type(n).__name__) + raise TypeError("n must be a native str (got %s)" % type(n).__name__) def extract_bytes(mv): @@ -80,5 +81,5 @@ def extract_bytes(mv): return mv raise ValueError( - 'extract_bytes() only accepts bytes and memoryview/buffer', + "extract_bytes() only accepts bytes and memoryview/buffer", ) diff --git a/cheroot/_compat.pyi b/cheroot/_compat.pyi index 67d93cf6c2..df1ca6b7ce 100644 --- a/cheroot/_compat.pyi +++ b/cheroot/_compat.pyi @@ -18,5 +18,4 @@ def ntob(n: str, encoding: str = ...) -> bytes: ... def ntou(n: str, encoding: str = ...) -> str: ... def bton(b: bytes, encoding: str = ...) -> str: ... def assert_native(n: str) -> None: ... - def extract_bytes(mv: Union[memoryview, bytes]) -> bytes: ... diff --git a/cheroot/cli.py b/cheroot/cli.py index cd168e9184..918e6e0fa2 100644 --- a/cheroot/cli.py +++ b/cheroot/cli.py @@ -69,7 +69,7 @@ class AbstractSocket(BindLocation): def __init__(self, abstract_socket): """Initialize.""" - self.bind_addr = '\x00{sock_path}'.format(sock_path=abstract_socket) + self.bind_addr = "\x00{sock_path}".format(sock_path=abstract_socket) class Application: @@ -78,8 +78,8 @@ class Application: @classmethod def resolve(cls, full_path): """Read WSGI app/Gateway path string and import application module.""" - mod_path, _, app_path = full_path.partition(':') - app = getattr(import_module(mod_path), app_path or 'application') + mod_path, _, app_path = full_path.partition(":") + app = getattr(import_module(mod_path), app_path or "application") # suppress the `TypeError` exception, just in case `app` is not a class with suppress(TypeError): if issubclass(app, server.Gateway): @@ -91,8 +91,8 @@ def __init__(self, wsgi_app): """Initialize.""" if not callable(wsgi_app): raise TypeError( - 'Application must be a callable object or ' - 'cheroot.server.Gateway subclass', + "Application must be a callable object or " + "cheroot.server.Gateway subclass", ) self.wsgi_app = wsgi_app @@ -101,7 +101,7 @@ def server_args(self, parsed_args): args = { arg: value for arg, value in vars(parsed_args).items() - if not arg.startswith('_') and value is not None + if not arg.startswith("_") and value is not None } args.update(vars(self)) return args @@ -121,11 +121,11 @@ def __init__(self, gateway): def server(self, parsed_args): """Server.""" server_args = vars(self) - server_args['bind_addr'] = parsed_args['bind_addr'] + server_args["bind_addr"] = parsed_args["bind_addr"] if parsed_args.max is not None: - server_args['maxthreads'] = parsed_args.max + server_args["maxthreads"] = parsed_args.max if parsed_args.numthreads is not None: - server_args['minthreads'] = parsed_args.numthreads + server_args["minthreads"] = parsed_args.numthreads return server.HTTPServer(**server_args) @@ -135,12 +135,12 @@ def parse_wsgi_bind_location(bind_addr_string): # this is the first condition to verify, otherwise the urlparse # validation would detect //@ as a valid url with a hostname # with value: "" and port: None - if bind_addr_string.startswith('@'): + if bind_addr_string.startswith("@"): return AbstractSocket(bind_addr_string[1:]) # try and match for an IP/hostname and port match = urllib.parse.urlparse( - '//{addr}'.format(addr=bind_addr_string), + "//{addr}".format(addr=bind_addr_string), ) try: addr = match.hostname @@ -160,69 +160,69 @@ def parse_wsgi_bind_addr(bind_addr_string): _arg_spec = { - '_wsgi_app': { - 'metavar': 'APP_MODULE', - 'type': Application.resolve, - 'help': 'WSGI application callable or cheroot.server.Gateway subclass', + "_wsgi_app": { + "metavar": "APP_MODULE", + "type": Application.resolve, + "help": "WSGI application callable or cheroot.server.Gateway subclass", }, - '--bind': { - 'metavar': 'ADDRESS', - 'dest': 'bind_addr', - 'type': parse_wsgi_bind_addr, - 'default': '[::1]:8000', - 'help': 'Network interface to listen on (default: [::1]:8000)', + "--bind": { + "metavar": "ADDRESS", + "dest": "bind_addr", + "type": parse_wsgi_bind_addr, + "default": "[::1]:8000", + "help": "Network interface to listen on (default: [::1]:8000)", }, - '--chdir': { - 'metavar': 'PATH', - 'type': os.chdir, - 'help': 'Set the working directory', + "--chdir": { + "metavar": "PATH", + "type": os.chdir, + "help": "Set the working directory", }, - '--server-name': { - 'dest': 'server_name', - 'type': str, - 'help': 'Web server name to be advertised via Server HTTP header', + "--server-name": { + "dest": "server_name", + "type": str, + "help": "Web server name to be advertised via Server HTTP header", }, - '--threads': { - 'metavar': 'INT', - 'dest': 'numthreads', - 'type': int, - 'help': 'Minimum number of worker threads', + "--threads": { + "metavar": "INT", + "dest": "numthreads", + "type": int, + "help": "Minimum number of worker threads", }, - '--max-threads': { - 'metavar': 'INT', - 'dest': 'max', - 'type': int, - 'help': 'Maximum number of worker threads', + "--max-threads": { + "metavar": "INT", + "dest": "max", + "type": int, + "help": "Maximum number of worker threads", }, - '--timeout': { - 'metavar': 'INT', - 'dest': 'timeout', - 'type': int, - 'help': 'Timeout in seconds for accepted connections', + "--timeout": { + "metavar": "INT", + "dest": "timeout", + "type": int, + "help": "Timeout in seconds for accepted connections", }, - '--shutdown-timeout': { - 'metavar': 'INT', - 'dest': 'shutdown_timeout', - 'type': int, - 'help': 'Time in seconds to wait for worker threads to cleanly exit', + "--shutdown-timeout": { + "metavar": "INT", + "dest": "shutdown_timeout", + "type": int, + "help": "Time in seconds to wait for worker threads to cleanly exit", }, - '--request-queue-size': { - 'metavar': 'INT', - 'dest': 'request_queue_size', - 'type': int, - 'help': 'Maximum number of queued connections', + "--request-queue-size": { + "metavar": "INT", + "dest": "request_queue_size", + "type": int, + "help": "Maximum number of queued connections", }, - '--accepted-queue-size': { - 'metavar': 'INT', - 'dest': 'accepted_queue_size', - 'type': int, - 'help': 'Maximum number of active requests in queue', + "--accepted-queue-size": { + "metavar": "INT", + "dest": "accepted_queue_size", + "type": int, + "help": "Maximum number of active requests in queue", }, - '--accepted-queue-timeout': { - 'metavar': 'INT', - 'dest': 'accepted_queue_timeout', - 'type': int, - 'help': 'Timeout in seconds for putting requests into queue', + "--accepted-queue-timeout": { + "metavar": "INT", + "dest": "accepted_queue_timeout", + "type": int, + "help": "Timeout in seconds for putting requests into queue", }, } @@ -230,14 +230,14 @@ def parse_wsgi_bind_addr(bind_addr_string): def main(): """Create a new Cheroot instance with arguments from the command line.""" parser = argparse.ArgumentParser( - description='Start an instance of the Cheroot WSGI/HTTP server.', + description="Start an instance of the Cheroot WSGI/HTTP server.", ) for arg, spec in _arg_spec.items(): parser.add_argument(arg, **spec) raw_args = parser.parse_args() # ensure cwd in sys.path - '' in sys.path or sys.path.insert(0, '') + "" in sys.path or sys.path.insert(0, "") # create a server based on the arguments provided raw_args._wsgi_app.server(raw_args).safe_start() diff --git a/cheroot/connections.py b/cheroot/connections.py index df70e6ea02..6e0040e274 100644 --- a/cheroot/connections.py +++ b/cheroot/connections.py @@ -18,6 +18,7 @@ try: from ctypes import windll, WinError import ctypes.wintypes + _SetHandleInformation = windll.kernel32.SetHandleInformation _SetHandleInformation.argtypes = [ ctypes.wintypes.HANDLE, @@ -26,6 +27,7 @@ ] _SetHandleInformation.restype = ctypes.wintypes.BOOL except ImportError: + def prevent_socket_inheritance(sock): """Stub inheritance prevention. @@ -33,11 +35,13 @@ def prevent_socket_inheritance(sock): """ pass else: + def prevent_socket_inheritance(sock): """Mark the given socket fd as non-inheritable (Windows).""" if not _SetHandleInformation(sock.fileno(), 1, 0): raise WinError() else: + def prevent_socket_inheritance(sock): """Mark the given socket fd as non-inheritable (POSIX).""" fd = sock.fileno() @@ -97,10 +101,7 @@ def select(self, timeout=None): Returns entries ready to read in the form: (socket_file_descriptor, connection) """ - return ( - (key.fd, key.data) - for key, _ in self._selector.select(timeout=timeout) - ) + return ((key.fd, key.data) for key, _ in self._selector.select(timeout=timeout)) def close(self): """Close the selector.""" @@ -129,7 +130,8 @@ def __init__(self, server): self._selector.register( server.socket.fileno(), - selectors.EVENT_READ, data=server, + selectors.EVENT_READ, + data=server, ) def put(self, conn): @@ -145,7 +147,9 @@ def put(self, conn): self.server.process_conn(conn) else: self._selector.register( - conn.socket.fileno(), selectors.EVENT_READ, data=conn, + conn.socket.fileno(), + selectors.EVENT_READ, + data=conn, ) def _expire(self, threshold): @@ -235,7 +239,7 @@ def _run(self, expiration_interval): self._remove_invalid_sockets() continue - for (sock_fd, conn) in active_list: + for sock_fd, conn in active_list: if conn is self.server: # New connection new_conn = self._from_server_socket(self.server.socket) @@ -280,10 +284,10 @@ def _remove_invalid_sockets(self): def _from_server_socket(self, server_socket): # noqa: C901 # FIXME try: s, addr = server_socket.accept() - if self.server.stats['Enabled']: - self.server.stats['Accepts'] += 1 + if self.server.stats["Enabled"]: + self.server.stats["Accepts"] += 1 prevent_socket_inheritance(s) - if hasattr(s, 'settimeout'): + if hasattr(s, "settimeout"): s.settimeout(self.server.timeout) mf = MakeFile @@ -294,39 +298,39 @@ def _from_server_socket(self, server_socket): # noqa: C901 # FIXME s, ssl_env = self.server.ssl_adapter.wrap(s) except errors.FatalSSLAlert as tls_connection_drop_error: self.server.error_log( - f'Client {addr !s} lost — peer dropped the TLS ' - 'connection suddenly, during handshake: ' - f'{tls_connection_drop_error !s}', + f"Client {addr !s} lost — peer dropped the TLS " + "connection suddenly, during handshake: " + f"{tls_connection_drop_error !s}", ) return except errors.NoSSLError as http_over_https_err: self.server.error_log( - f'Client {addr !s} attempted to speak plain HTTP into ' - 'a TCP connection configured for TLS-only traffic — ' - 'trying to send back a plain HTTP error response: ' - f'{http_over_https_err !s}', + f"Client {addr !s} attempted to speak plain HTTP into " + "a TCP connection configured for TLS-only traffic — " + "trying to send back a plain HTTP error response: " + f"{http_over_https_err !s}", ) msg = ( - 'The client sent a plain HTTP request, but ' - 'this server only speaks HTTPS on this port.' + "The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port." ) buf = [ - '%s 400 Bad Request\r\n' % self.server.protocol, - 'Content-Length: %s\r\n' % len(msg), - 'Content-Type: text/plain\r\n\r\n', + "%s 400 Bad Request\r\n" % self.server.protocol, + "Content-Length: %s\r\n" % len(msg), + "Content-Type: text/plain\r\n\r\n", msg, ] - wfile = mf(s, 'wb', io.DEFAULT_BUFFER_SIZE) + wfile = mf(s, "wb", io.DEFAULT_BUFFER_SIZE) try: - wfile.write(''.join(buf).encode('ISO-8859-1')) + wfile.write("".join(buf).encode("ISO-8859-1")) except OSError as ex: if ex.args[0] not in errors.socket_errors_to_ignore: raise return mf = self.server.ssl_adapter.makefile # Re-apply our timeout since we may have a new socket object - if hasattr(s, 'settimeout'): + if hasattr(s, "settimeout"): s.settimeout(self.server.timeout) conn = self.server.ConnectionClass(self.server, s, mf) @@ -338,10 +342,10 @@ def _from_server_socket(self, server_socket): # noqa: C901 # FIXME # figure out if AF_INET or AF_INET6. if len(s.getsockname()) == 2: # AF_INET - addr = ('0.0.0.0', 0) + addr = ("0.0.0.0", 0) else: # AF_INET6 - addr = ('::', 0) + addr = ("::", 0) conn.remote_addr = addr[0] conn.remote_port = addr[1] @@ -354,8 +358,8 @@ def _from_server_socket(self, server_socket): # noqa: C901 # FIXME # accept() by default return except OSError as ex: - if self.server.stats['Enabled']: - self.server.stats['Socket Errors'] += 1 + if self.server.stats["Enabled"]: + self.server.stats["Socket Errors"] += 1 if ex.args[0] in errors.socket_error_eintr: # I *think* this is right. EINTR should occur when a signal # is received during the accept() call; all docs say retry @@ -376,7 +380,7 @@ def _from_server_socket(self, server_socket): # noqa: C901 # FIXME def close(self): """Close all monitored connections.""" - for (_, conn) in self._selector.connections: + for _, conn in self._selector.connections: if conn is not self.server: # server closes its own socket conn.close() self._selector.close() diff --git a/cheroot/errors.py b/cheroot/errors.py index f6b588c2f5..87ab9bfe25 100644 --- a/cheroot/errors.py +++ b/cheroot/errors.py @@ -32,33 +32,44 @@ def plat_specific_errors(*errnames): return list(unique_nums - missing_attr) -socket_error_eintr = plat_specific_errors('EINTR', 'WSAEINTR') +socket_error_eintr = plat_specific_errors("EINTR", "WSAEINTR") socket_errors_to_ignore = plat_specific_errors( - 'EPIPE', - 'EBADF', 'WSAEBADF', - 'ENOTSOCK', 'WSAENOTSOCK', - 'ETIMEDOUT', 'WSAETIMEDOUT', - 'ECONNREFUSED', 'WSAECONNREFUSED', - 'ECONNRESET', 'WSAECONNRESET', - 'ECONNABORTED', 'WSAECONNABORTED', - 'ENETRESET', 'WSAENETRESET', - 'EHOSTDOWN', 'EHOSTUNREACH', + "EPIPE", + "EBADF", + "WSAEBADF", + "ENOTSOCK", + "WSAENOTSOCK", + "ETIMEDOUT", + "WSAETIMEDOUT", + "ECONNREFUSED", + "WSAECONNREFUSED", + "ECONNRESET", + "WSAECONNRESET", + "ECONNABORTED", + "WSAECONNABORTED", + "ENETRESET", + "WSAENETRESET", + "EHOSTDOWN", + "EHOSTUNREACH", ) -socket_errors_to_ignore.append('timed out') -socket_errors_to_ignore.append('The read operation timed out') +socket_errors_to_ignore.append("timed out") +socket_errors_to_ignore.append("The read operation timed out") socket_errors_nonblocking = plat_specific_errors( - 'EAGAIN', 'EWOULDBLOCK', 'WSAEWOULDBLOCK', + "EAGAIN", + "EWOULDBLOCK", + "WSAEWOULDBLOCK", ) -if sys.platform == 'darwin': - socket_errors_to_ignore.extend(plat_specific_errors('EPROTOTYPE')) - socket_errors_nonblocking.extend(plat_specific_errors('EPROTOTYPE')) +if sys.platform == "darwin": + socket_errors_to_ignore.extend(plat_specific_errors("EPROTOTYPE")) + socket_errors_nonblocking.extend(plat_specific_errors("EPROTOTYPE")) acceptable_sock_shutdown_error_codes = { errno.ENOTCONN, - errno.EPIPE, errno.ESHUTDOWN, # corresponds to BrokenPipeError in Python 3 + errno.EPIPE, + errno.ESHUTDOWN, # corresponds to BrokenPipeError in Python 3 errno.ECONNRESET, # corresponds to ConnectionResetError in Python 3 } """Errors that may happen during the connection close sequence. diff --git a/cheroot/makefile.py b/cheroot/makefile.py index 77878c13b7..e446ae1b58 100644 --- a/cheroot/makefile.py +++ b/cheroot/makefile.py @@ -25,7 +25,7 @@ def write(self, b): return len(b) def _flush_unlocked(self): - self._checkClosed('flush of closed file') + self._checkClosed("flush of closed file") while self._write_buf: try: # ssl sockets only except 'bytes', not bytearrays @@ -39,7 +39,7 @@ def _flush_unlocked(self): class StreamReader(io.BufferedReader): """Socket stream reader.""" - def __init__(self, sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): + def __init__(self, sock, mode="r", bufsize=io.DEFAULT_BUFFER_SIZE): """Initialize socket stream reader.""" super().__init__(socket.SocketIO(sock, mode), bufsize) self.bytes_read = 0 @@ -58,7 +58,7 @@ def has_data(self): class StreamWriter(BufferedWriter): """Socket stream writer.""" - def __init__(self, sock, mode='w', bufsize=io.DEFAULT_BUFFER_SIZE): + def __init__(self, sock, mode="w", bufsize=io.DEFAULT_BUFFER_SIZE): """Initialize socket stream writer.""" super().__init__(socket.SocketIO(sock, mode), bufsize) self.bytes_written = 0 @@ -70,7 +70,7 @@ def write(self, val, *args, **kwargs): return res -def MakeFile(sock, mode='r', bufsize=io.DEFAULT_BUFFER_SIZE): +def MakeFile(sock, mode="r", bufsize=io.DEFAULT_BUFFER_SIZE): """File object attached to a socket object.""" - cls = StreamReader if 'r' in mode else StreamWriter + cls = StreamReader if "r" in mode else StreamWriter return cls(sock, mode, bufsize) diff --git a/cheroot/server.py b/cheroot/server.py index 91564611c0..bb526c2d5f 100644 --- a/cheroot/server.py +++ b/cheroot/server.py @@ -89,18 +89,24 @@ __all__ = ( - 'HTTPRequest', 'HTTPConnection', 'HTTPServer', - 'HeaderReader', 'DropUnderscoreHeaderReader', - 'SizeCheckWrapper', 'KnownLengthRFile', 'ChunkedRFile', - 'Gateway', 'get_ssl_adapter_class', + "HTTPRequest", + "HTTPConnection", + "HTTPServer", + "HeaderReader", + "DropUnderscoreHeaderReader", + "SizeCheckWrapper", + "KnownLengthRFile", + "ChunkedRFile", + "Gateway", + "get_ssl_adapter_class", ) -IS_WINDOWS = platform.system() == 'Windows' +IS_WINDOWS = platform.system() == "Windows" """Flag indicating whether the app is running under Windows.""" -IS_GAE = os.getenv('SERVER_SOFTWARE', '').startswith('Google App Engine/') +IS_GAE = os.getenv("SERVER_SOFTWARE", "").startswith("Google App Engine/") """Flag indicating whether the app is running in GAE env. Ref: @@ -128,14 +134,14 @@ import struct -if IS_WINDOWS and hasattr(socket, 'AF_INET6'): - if not hasattr(socket, 'IPPROTO_IPV6'): +if IS_WINDOWS and hasattr(socket, "AF_INET6"): + if not hasattr(socket, "IPPROTO_IPV6"): socket.IPPROTO_IPV6 = 41 - if not hasattr(socket, 'IPV6_V6ONLY'): + if not hasattr(socket, "IPV6_V6ONLY"): socket.IPV6_V6ONLY = 27 -if not hasattr(socket, 'SO_PEERCRED'): +if not hasattr(socket, "SO_PEERCRED"): """ NOTE: the value for SO_PEERCRED can be architecture specific, in which case the getsockopt() will hopefully fail. The arch @@ -144,33 +150,50 @@ socket.SO_PEERCRED = 21 if IS_PPC else 17 -LF = b'\n' -CRLF = b'\r\n' -TAB = b'\t' -SPACE = b' ' -COLON = b':' -SEMICOLON = b';' -EMPTY = b'' -ASTERISK = b'*' -FORWARD_SLASH = b'/' -QUOTED_SLASH = b'%2F' -QUOTED_SLASH_REGEX = re.compile(b''.join((b'(?i)', QUOTED_SLASH))) +LF = b"\n" +CRLF = b"\r\n" +TAB = b"\t" +SPACE = b" " +COLON = b":" +SEMICOLON = b";" +EMPTY = b"" +ASTERISK = b"*" +FORWARD_SLASH = b"/" +QUOTED_SLASH = b"%2F" +QUOTED_SLASH_REGEX = re.compile(b"".join((b"(?i)", QUOTED_SLASH))) _STOPPING_FOR_INTERRUPT = Exception() # sentinel used during shutdown comma_separated_headers = [ - b'Accept', b'Accept-Charset', b'Accept-Encoding', - b'Accept-Language', b'Accept-Ranges', b'Allow', b'Cache-Control', - b'Connection', b'Content-Encoding', b'Content-Language', b'Expect', - b'If-Match', b'If-None-Match', b'Pragma', b'Proxy-Authenticate', b'TE', - b'Trailer', b'Transfer-Encoding', b'Upgrade', b'Vary', b'Via', b'Warning', - b'WWW-Authenticate', + b"Accept", + b"Accept-Charset", + b"Accept-Encoding", + b"Accept-Language", + b"Accept-Ranges", + b"Allow", + b"Cache-Control", + b"Connection", + b"Content-Encoding", + b"Content-Language", + b"Expect", + b"If-Match", + b"If-None-Match", + b"Pragma", + b"Proxy-Authenticate", + b"TE", + b"Trailer", + b"Transfer-Encoding", + b"Upgrade", + b"Vary", + b"Via", + b"Warning", + b"WWW-Authenticate", ] -if not hasattr(logging, 'statistics'): +if not hasattr(logging, "statistics"): logging.statistics = {} @@ -201,13 +224,13 @@ def __call__(self, rfile, hdict=None): # noqa: C901 # FIXME line = rfile.readline() if not line: # No more data--illegal end of headers - raise ValueError('Illegal end of headers.') + raise ValueError("Illegal end of headers.") if line == CRLF: # Normal end of headers break if not line.endswith(CRLF): - raise ValueError('HTTP requires CRLF terminators') + raise ValueError("HTTP requires CRLF terminators") if line[:1] in (SPACE, TAB): # NOTE: `type(line[0]) is int` and `type(line[:1]) is bytes`. @@ -220,7 +243,7 @@ def __call__(self, rfile, hdict=None): # noqa: C901 # FIXME try: k, v = line.split(COLON, 1) except ValueError: - raise ValueError('Illegal header line.') + raise ValueError("Illegal header line.") v = v.strip() k = self._transform_key(k) hname = k @@ -231,7 +254,7 @@ def __call__(self, rfile, hdict=None): # noqa: C901 # FIXME if k in comma_separated_headers: existing = hdict.get(hname) if existing: - v = b', '.join((existing, v)) + v = b", ".join((existing, v)) hdict[hname] = v return hdict @@ -249,7 +272,7 @@ class DropUnderscoreHeaderReader(HeaderReader): def _allow_header(self, key_name): orig = super(DropUnderscoreHeaderReader, self)._allow_header(key_name) - return orig and '_' not in key_name + return orig and "_" not in key_name class SizeCheckWrapper: @@ -371,7 +394,7 @@ def read(self, size=None): :returns: chunk from ``rfile``, limited by size if specified """ if self.remaining == 0: - return b'' + return b"" if size is None: size = self.remaining else: @@ -391,7 +414,7 @@ def readline(self, size=None): :rtype: bytes """ if self.remaining == 0: - return b'' + return b"" if size is None: size = self.remaining else: @@ -469,7 +492,8 @@ def _fetch(self): if self.maxlen and self.bytes_read > self.maxlen: raise errors.MaxSizeExceeded( - 'Request Entity Too Large', self.maxlen, + "Request Entity Too Large", + self.maxlen, ) line = line.strip().split(SEMICOLON, 1) @@ -479,18 +503,19 @@ def _fetch(self): chunk_size = int(chunk_size, 16) except ValueError: raise ValueError( - 'Bad chunked transfer size: {chunk_size!r}'. - format(chunk_size=chunk_size), + "Bad chunked transfer size: {chunk_size!r}".format( + chunk_size=chunk_size + ), ) if chunk_size <= 0: self.closed = True return -# if line: chunk_extension = line[0] + # if line: chunk_extension = line[0] if self.maxlen and self.bytes_read + chunk_size > self.maxlen: - raise IOError('Request Entity Too Large') + raise IOError("Request Entity Too Large") chunk = self.rfile.read(chunk_size) self.bytes_read += len(chunk) @@ -500,7 +525,7 @@ def _fetch(self): if crlf != CRLF: raise ValueError( "Bad chunked transfer coding (expected '\\r\\n', " - 'got ' + repr(crlf) + ')', + "got " + repr(crlf) + ")", ) def read(self, size=None): @@ -607,24 +632,24 @@ def read_trailer_lines(self): """ if not self.closed: raise ValueError( - 'Cannot read trailers until the request body has been read.', + "Cannot read trailers until the request body has been read.", ) while True: line = self.rfile.readline() if not line: # No more data--illegal end of headers - raise ValueError('Illegal end of headers.') + raise ValueError("Illegal end of headers.") self.bytes_read += len(line) if self.maxlen and self.bytes_read > self.maxlen: - raise IOError('Request Entity Too Large') + raise IOError("Request Entity Too Large") if line == CRLF: # Normal end of headers break if not line.endswith(CRLF): - raise ValueError('HTTP requires CRLF terminators') + raise ValueError("HTTP requires CRLF terminators") yield line @@ -688,14 +713,14 @@ def __init__(self, server, conn, proxy_mode=False, strict_mode=True): self.ready = False self.started_request = False - self.scheme = b'http' + self.scheme = b"http" if self.server.ssl_adapter is not None: - self.scheme = b'https' + self.scheme = b"https" # Use the lowest-common protocol in case read_request_line errors. - self.response_protocol = 'HTTP/1.0' + self.response_protocol = "HTTP/1.0" self.inheaders = {} - self.status = '' + self.status = "" self.outheaders = [] self.sent_headers = False self.close_connection = self.__class__.close_connection @@ -714,9 +739,9 @@ def parse_request(self): success = self.read_request_line() except errors.MaxSizeExceeded: self.simple_response( - '414 Request-URI Too Long', - 'The Request-URI sent with the request exceeds the maximum ' - 'allowed bytes.', + "414 Request-URI Too Long", + "The Request-URI sent with the request exceeds the maximum " + "allowed bytes.", ) return else: @@ -727,9 +752,9 @@ def parse_request(self): success = self.read_request_headers() except errors.MaxSizeExceeded: self.simple_response( - '413 Request Entity Too Large', - 'The headers sent with the request exceed the maximum ' - 'allowed bytes.', + "413 Request Entity Too Large", + "The headers sent with the request exceed the maximum " + "allowed bytes.", ) return else: @@ -771,31 +796,35 @@ def read_request_line(self): # noqa: C901 # FIXME if not request_line.endswith(CRLF): self.simple_response( - '400 Bad Request', 'HTTP requires CRLF terminators', + "400 Bad Request", + "HTTP requires CRLF terminators", ) return False try: method, uri, req_protocol = request_line.strip().split(SPACE, 2) - if not req_protocol.startswith(b'HTTP/'): + if not req_protocol.startswith(b"HTTP/"): self.simple_response( - '400 Bad Request', 'Malformed Request-Line: bad protocol', + "400 Bad Request", + "Malformed Request-Line: bad protocol", ) return False - rp = req_protocol[5:].split(b'.', 1) + rp = req_protocol[5:].split(b".", 1) if len(rp) != 2: self.simple_response( - '400 Bad Request', 'Malformed Request-Line: bad version', + "400 Bad Request", + "Malformed Request-Line: bad version", ) return False rp = tuple(map(int, rp)) # Minor.Major must be threat as integers if rp > (1, 1): self.simple_response( - '505 HTTP Version Not Supported', 'Cannot fulfill request', + "505 HTTP Version Not Supported", + "Cannot fulfill request", ) return False except (ValueError, IndexError): - self.simple_response('400 Bad Request', 'Malformed Request-Line') + self.simple_response("400 Bad Request", "Malformed Request-Line") return False self.uri = uri @@ -803,23 +832,23 @@ def read_request_line(self): # noqa: C901 # FIXME if self.strict_mode and method != self.method: resp = ( - 'Malformed method name: According to RFC 2616 ' - '(section 5.1.1) and its successors ' - 'RFC 7230 (section 3.1.1) and RFC 7231 (section 4.1) ' - 'method names are case-sensitive and uppercase.' + "Malformed method name: According to RFC 2616 " + "(section 5.1.1) and its successors " + "RFC 7230 (section 3.1.1) and RFC 7231 (section 4.1) " + "method names are case-sensitive and uppercase." ) - self.simple_response('400 Bad Request', resp) + self.simple_response("400 Bad Request", resp) return False try: scheme, authority, path, qs, fragment = urllib.parse.urlsplit(uri) except UnicodeError: - self.simple_response('400 Bad Request', 'Malformed Request-URI') + self.simple_response("400 Bad Request", "Malformed Request-URI") return False - uri_is_absolute_form = (scheme or authority) + uri_is_absolute_form = scheme or authority - if self.method == b'OPTIONS': + if self.method == b"OPTIONS": # TODO: cover this branch with tests path = ( uri @@ -827,15 +856,15 @@ def read_request_line(self): # noqa: C901 # FIXME if (self.proxy_mode and uri_is_absolute_form) else path ) - elif self.method == b'CONNECT': + elif self.method == b"CONNECT": # TODO: cover this branch with tests if not self.proxy_mode: - self.simple_response('405 Method Not Allowed') + self.simple_response("405 Method Not Allowed") return False # `urlsplit()` above parses "example.com:3128" as path part of URI. # this is a workaround, which makes it detect netloc correctly - uri_split = urllib.parse.urlsplit(b''.join((b'//', uri))) + uri_split = urllib.parse.urlsplit(b"".join((b"//", uri))) _scheme, _authority, _path, _qs, _fragment = uri_split _port = EMPTY try: @@ -848,15 +877,13 @@ def read_request_line(self): # noqa: C901 # FIXME # invalid URIs without raising errors # https://tools.ietf.org/html/rfc7230#section-5.3.3 invalid_path = ( - _authority != uri - or not _port - or any((_scheme, _path, _qs, _fragment)) + _authority != uri or not _port or any((_scheme, _path, _qs, _fragment)) ) if invalid_path: self.simple_response( - '400 Bad Request', - 'Invalid path in Request-URI: request-' - 'target must match authority-form.', + "400 Bad Request", + "Invalid path in Request-URI: request-" + "target must match authority-form.", ) return False @@ -864,17 +891,15 @@ def read_request_line(self): # noqa: C901 # FIXME scheme = qs = fragment = EMPTY else: disallowed_absolute = ( - self.strict_mode - and not self.proxy_mode - and uri_is_absolute_form + self.strict_mode and not self.proxy_mode and uri_is_absolute_form ) if disallowed_absolute: # https://tools.ietf.org/html/rfc7230#section-5.3.2 # (absolute form) """Absolute URI is only allowed within proxies.""" self.simple_response( - '400 Bad Request', - 'Absolute URI not allowed if server is not a proxy.', + "400 Bad Request", + "Absolute URI not allowed if server is not a proxy.", ) return False @@ -888,25 +913,25 @@ def read_request_line(self): # noqa: C901 # FIXME # (origin_form) and """Path should start with a forward slash.""" resp = ( - 'Invalid path in Request-URI: request-target must contain ' - 'origin-form which starts with absolute-path (URI ' + "Invalid path in Request-URI: request-target must contain " + "origin-form which starts with absolute-path (URI " 'starting with a slash "/").' ) - self.simple_response('400 Bad Request', resp) + self.simple_response("400 Bad Request", resp) return False if fragment: self.simple_response( - '400 Bad Request', - 'Illegal #fragment in Request-URI.', + "400 Bad Request", + "Illegal #fragment in Request-URI.", ) return False if path is None: # FIXME: It looks like this case cannot happen self.simple_response( - '400 Bad Request', - 'Invalid path in Request-URI.', + "400 Bad Request", + "Invalid path in Request-URI.", ) return False @@ -926,7 +951,7 @@ def read_request_line(self): # noqa: C901 # FIXME for x in QUOTED_SLASH_REGEX.split(path) ] except ValueError as ex: - self.simple_response('400 Bad Request', ex.args[0]) + self.simple_response("400 Bad Request", ex.args[0]) return False path = QUOTED_SLASH.join(atoms) @@ -957,11 +982,11 @@ def read_request_line(self): # noqa: C901 # FIXME sp = int(self.server.protocol[5]), int(self.server.protocol[7]) if sp[0] != rp[0]: - self.simple_response('505 HTTP Version Not Supported') + self.simple_response("505 HTTP Version Not Supported") return False self.request_protocol = req_protocol - self.response_protocol = 'HTTP/%s.%s' % min(rp, sp) + self.response_protocol = "HTTP/%s.%s" % min(rp, sp) return True @@ -977,55 +1002,55 @@ def read_request_headers(self): # noqa: C901 # FIXME try: self.header_reader(self.rfile, self.inheaders) except ValueError as ex: - self.simple_response('400 Bad Request', ex.args[0]) + self.simple_response("400 Bad Request", ex.args[0]) return False mrbs = self.server.max_request_body_size try: - cl = int(self.inheaders.get(b'Content-Length', 0)) + cl = int(self.inheaders.get(b"Content-Length", 0)) except ValueError: self.simple_response( - '400 Bad Request', - 'Malformed Content-Length Header.', + "400 Bad Request", + "Malformed Content-Length Header.", ) return False if mrbs and cl > mrbs: self.simple_response( - '413 Request Entity Too Large', - 'The entity sent with the request exceeds the maximum ' - 'allowed bytes.', + "413 Request Entity Too Large", + "The entity sent with the request exceeds the maximum " + "allowed bytes.", ) return False # Persistent connection support - if self.response_protocol == 'HTTP/1.1': + if self.response_protocol == "HTTP/1.1": # Both server and client are HTTP/1.1 - if self.inheaders.get(b'Connection', b'') == b'close': + if self.inheaders.get(b"Connection", b"") == b"close": self.close_connection = True else: # Either the server or client (or both) are HTTP/1.0 - if self.inheaders.get(b'Connection', b'') != b'Keep-Alive': + if self.inheaders.get(b"Connection", b"") != b"Keep-Alive": self.close_connection = True # Transfer-Encoding support te = None - if self.response_protocol == 'HTTP/1.1': - te = self.inheaders.get(b'Transfer-Encoding') + if self.response_protocol == "HTTP/1.1": + te = self.inheaders.get(b"Transfer-Encoding") if te: - te = [x.strip().lower() for x in te.split(b',') if x.strip()] + te = [x.strip().lower() for x in te.split(b",") if x.strip()] self.chunked_read = False if te: for enc in te: - if enc == b'chunked': + if enc == b"chunked": self.chunked_read = True else: # Note that, even if we see "chunked", we must reject # if there is an extension we don't recognize. - self.simple_response('501 Unimplemented') + self.simple_response("501 Unimplemented") self.close_connection = True return False @@ -1046,14 +1071,19 @@ def read_request_headers(self): # noqa: C901 # FIXME # # We used to do 3, but are now doing 1. Maybe we'll do 2 someday, # but it seems like it would be a big slowdown for such a rare case. - if self.inheaders.get(b'Expect', b'') == b'100-continue': + if self.inheaders.get(b"Expect", b"") == b"100-continue": # Don't use simple_response here, because it emits headers # we don't want. See # https://github.com/cherrypy/cherrypy/issues/951 - msg = b''.join(( - self.server.protocol.encode('ascii'), SPACE, b'100 Continue', - CRLF, CRLF, - )) + msg = b"".join( + ( + self.server.protocol.encode("ascii"), + SPACE, + b"100 Continue", + CRLF, + CRLF, + ) + ) try: self.conn.wfile.write(msg) except socket.error as ex: @@ -1067,13 +1097,13 @@ def respond(self): if self.chunked_read: self.rfile = ChunkedRFile(self.conn.rfile, mrbs) else: - cl = int(self.inheaders.get(b'Content-Length', 0)) + cl = int(self.inheaders.get(b"Content-Length", 0)) if mrbs and mrbs < cl: if not self.sent_headers: self.simple_response( - '413 Request Entity Too Large', - 'The entity sent with the request exceeds the ' - 'maximum allowed bytes.', + "413 Request Entity Too Large", + "The entity sent with the request exceeds the " + "maximum allowed bytes.", ) return self.rfile = KnownLengthRFile(self.conn.rfile, cl) @@ -1082,37 +1112,37 @@ def respond(self): self.ready and self.ensure_headers_sent() if self.chunked_write: - self.conn.wfile.write(b'0\r\n\r\n') + self.conn.wfile.write(b"0\r\n\r\n") - def simple_response(self, status, msg=''): + def simple_response(self, status, msg=""): """Write a simple response back to the client.""" status = str(status) - proto_status = '%s %s\r\n' % (self.server.protocol, status) - content_length = 'Content-Length: %s\r\n' % len(msg) - content_type = 'Content-Type: text/plain\r\n' + proto_status = "%s %s\r\n" % (self.server.protocol, status) + content_length = "Content-Length: %s\r\n" % len(msg) + content_type = "Content-Type: text/plain\r\n" buf = [ - proto_status.encode('ISO-8859-1'), - content_length.encode('ISO-8859-1'), - content_type.encode('ISO-8859-1'), + proto_status.encode("ISO-8859-1"), + content_length.encode("ISO-8859-1"), + content_type.encode("ISO-8859-1"), ] - if status[:3] in ('413', '414'): + if status[:3] in ("413", "414"): # Request Entity Too Large / Request-URI Too Long self.close_connection = True - if self.response_protocol == 'HTTP/1.1': + if self.response_protocol == "HTTP/1.1": # This will not be true for 414, since read_request_line # usually raises 414 before reading the whole line, and we # therefore cannot know the proper response_protocol. - buf.append(b'Connection: close\r\n') + buf.append(b"Connection: close\r\n") else: # HTTP/1.0 had no 413/414 status nor Connection header. # Emit 400 instead and trust the message body is enough. - status = '400 Bad Request' + status = "400 Bad Request" buf.append(CRLF) if msg: if isinstance(msg, str): - msg = msg.encode('ISO-8859-1') + msg = msg.encode("ISO-8859-1") buf.append(msg) try: @@ -1130,7 +1160,7 @@ def ensure_headers_sent(self): def write(self, chunk): """Write unbuffered data to the client.""" if self.chunked_write and chunk: - chunk_size_hex = hex(len(chunk))[2:].encode('ascii') + chunk_size_hex = hex(len(chunk))[2:].encode("ascii") buf = [chunk_size_hex, CRLF, chunk, CRLF] self.conn.wfile.write(EMPTY.join(buf)) else: @@ -1148,7 +1178,7 @@ def send_headers(self): # noqa: C901 # FIXME if status == 413: # Request Entity Too Large. Close conn to avoid garbage. self.close_connection = True - elif b'content-length' not in hkeys: + elif b"content-length" not in hkeys: # "All 1xx (informational), 204 (no content), # and 304 (not modified) responses MUST NOT # include a message-body." So no point chunking. @@ -1156,13 +1186,12 @@ def send_headers(self): # noqa: C901 # FIXME pass else: needs_chunked = ( - self.response_protocol == 'HTTP/1.1' - and self.method != b'HEAD' + self.response_protocol == "HTTP/1.1" and self.method != b"HEAD" ) if needs_chunked: # Use the chunked transfer-coding self.chunked_write = True - self.outheaders.append((b'Transfer-Encoding', b'chunked')) + self.outheaders.append((b"Transfer-Encoding", b"chunked")) else: # Closing the conn is the only way to determine len. self.close_connection = True @@ -1173,23 +1202,25 @@ def send_headers(self): # noqa: C901 # FIXME can_keep = self.server.can_add_keepalive_connection self.close_connection = not can_keep - if b'connection' not in hkeys: - if self.response_protocol == 'HTTP/1.1': + if b"connection" not in hkeys: + if self.response_protocol == "HTTP/1.1": # Both server and client are HTTP/1.1 or better if self.close_connection: - self.outheaders.append((b'Connection', b'close')) + self.outheaders.append((b"Connection", b"close")) else: # Server and/or client are HTTP/1.0 if not self.close_connection: - self.outheaders.append((b'Connection', b'Keep-Alive')) + self.outheaders.append((b"Connection", b"Keep-Alive")) - if (b'Connection', b'Keep-Alive') in self.outheaders: - self.outheaders.append(( - b'Keep-Alive', - u'timeout={connection_timeout}'. - format(connection_timeout=self.server.timeout). - encode('ISO-8859-1'), - )) + if (b"Connection", b"Keep-Alive") in self.outheaders: + self.outheaders.append( + ( + b"Keep-Alive", + "timeout={connection_timeout}".format( + connection_timeout=self.server.timeout + ).encode("ISO-8859-1"), + ) + ) if (not self.close_connection) and (not self.chunked_read): # Read any remaining request body data on the socket. @@ -1204,23 +1235,27 @@ def send_headers(self): # noqa: C901 # FIXME # requirement is not be construed as preventing a server from # defending itself against denial-of-service attacks, or from # badly broken client implementations." - remaining = getattr(self.rfile, 'remaining', 0) + remaining = getattr(self.rfile, "remaining", 0) if remaining > 0: self.rfile.read(remaining) - if b'date' not in hkeys: - self.outheaders.append(( - b'Date', - email.utils.formatdate(usegmt=True).encode('ISO-8859-1'), - )) + if b"date" not in hkeys: + self.outheaders.append( + ( + b"Date", + email.utils.formatdate(usegmt=True).encode("ISO-8859-1"), + ) + ) - if b'server' not in hkeys: - self.outheaders.append(( - b'Server', - self.server.server_name.encode('ISO-8859-1'), - )) + if b"server" not in hkeys: + self.outheaders.append( + ( + b"Server", + self.server.server_name.encode("ISO-8859-1"), + ) + ) - proto = self.server.protocol.encode('ascii') + proto = self.server.protocol.encode("ascii") buf = [proto + SPACE + self.status + CRLF] for k, v in self.outheaders: buf.append(k + COLON + SPACE + v + CRLF) @@ -1254,8 +1289,8 @@ def __init__(self, server, sock, makefile=MakeFile): """ self.server = server self.socket = sock - self.rfile = makefile(sock, 'rb', self.rbufsize) - self.wfile = makefile(sock, 'wb', self.wbufsize) + self.rfile = makefile(sock, "rb", self.rbufsize) + self.wfile = makefile(sock, "wb", self.wbufsize) self.requests_seen = 0 self.peercreds_enabled = self.server.peercreds_enabled @@ -1263,12 +1298,8 @@ def __init__(self, server, sock, makefile=MakeFile): # LRU cached methods: # Ref: https://stackoverflow.com/a/14946506/595220 - self.resolve_peer_creds = ( - lru_cache(maxsize=1)(self.resolve_peer_creds) - ) - self.get_peer_creds = ( - lru_cache(maxsize=1)(self.get_peer_creds) - ) + self.resolve_peer_creds = lru_cache(maxsize=1)(self.resolve_peer_creds) + self.get_peer_creds = lru_cache(maxsize=1)(self.get_peer_creds) def communicate(self): # noqa: C901 # FIXME """Read each request and respond appropriately. @@ -1279,7 +1310,7 @@ def communicate(self): # noqa: C901 # FIXME try: req = self.RequestHandlerClass(self.server, self) req.parse_request() - if self.server.stats['Enabled']: + if self.server.stats["Enabled"]: self.requests_seen += 1 if not req.ready: # Something went wrong in the parsing (and the server has @@ -1294,20 +1325,21 @@ def communicate(self): # noqa: C901 # FIXME except socket.error as ex: errnum = ex.args[0] # sadly SSL sockets return a different (longer) time out string - timeout_errs = 'timed out', 'The read operation timed out' + timeout_errs = "timed out", "The read operation timed out" if errnum in timeout_errs: # Don't error if we're between requests; only error # if 1) no request has been started at all, or 2) we're # in the middle of a request. # See https://github.com/cherrypy/cherrypy/issues/853 if (not request_seen) or (req and req.started_request): - self._conditional_error(req, '408 Request Timeout') + self._conditional_error(req, "408 Request Timeout") elif errnum not in errors.socket_errors_to_ignore: self.server.error_log( - 'socket.error %s' % repr(errnum), - level=logging.WARNING, traceback=True, + "socket.error %s" % repr(errnum), + level=logging.WARNING, + traceback=True, ) - self._conditional_error(req, '500 Internal Server Error') + self._conditional_error(req, "500 Internal Server Error") except (KeyboardInterrupt, SystemExit): raise except errors.FatalSSLAlert: @@ -1316,9 +1348,11 @@ def communicate(self): # noqa: C901 # FIXME self._handle_no_ssl(req) except Exception as ex: self.server.error_log( - repr(ex), level=logging.ERROR, traceback=True, + repr(ex), + level=logging.ERROR, + traceback=True, ) - self._conditional_error(req, '500 Internal Server Error') + self._conditional_error(req, "500 Internal Server Error") return False linger = False @@ -1332,12 +1366,12 @@ def _handle_no_ssl(self, req): except AttributeError: # self.socket is of OpenSSL.SSL.Connection type resp_sock = self.socket._socket - self.wfile = StreamWriter(resp_sock, 'wb', self.wbufsize) + self.wfile = StreamWriter(resp_sock, "wb", self.wbufsize) msg = ( - 'The client sent a plain HTTP request, but ' - 'this server only speaks HTTPS on this port.' + "The client sent a plain HTTP request, but " + "this server only speaks HTTPS on this port." ) - req.simple_response('400 Bad Request', msg) + req.simple_response("400 Bad Request", msg) self.linger = True def _conditional_error(self, req, response): @@ -1387,22 +1421,23 @@ def get_peer_creds(self): # LRU cached on per-instance basis, see __init__ RuntimeError: in case of SO_PEERCRED lookup unsupported or disabled """ - PEERCRED_STRUCT_DEF = '3i' + PEERCRED_STRUCT_DEF = "3i" if IS_WINDOWS or self.socket.family != socket.AF_UNIX: raise NotImplementedError( - 'SO_PEERCRED is only supported in Linux kernel and WSL', + "SO_PEERCRED is only supported in Linux kernel and WSL", ) elif not self.peercreds_enabled: raise RuntimeError( - 'Peer creds lookup is disabled within this server', + "Peer creds lookup is disabled within this server", ) try: peer_creds = self.socket.getsockopt( # FIXME: Use LOCAL_CREDS for BSD-like OSs # Ref: https://gist.github.com/LucaFilipozzi/e4f1e118202aff27af6aadebda1b5d91 # noqa - socket.SOL_SOCKET, socket.SO_PEERCRED, + socket.SOL_SOCKET, + socket.SO_PEERCRED, struct.calcsize(PEERCRED_STRUCT_DEF), ) except socket.error as socket_err: @@ -1446,13 +1481,13 @@ def resolve_peer_creds(self): # LRU cached on per-instance basis """ if not IS_UID_GID_RESOLVABLE: raise NotImplementedError( - 'UID/GID lookup is unavailable under current platform. ' - 'It can only be done under UNIX-like OS ' - 'but not under the Google App Engine', + "UID/GID lookup is unavailable under current platform. " + "It can only be done under UNIX-like OS " + "but not under the Google App Engine", ) elif not self.peercreds_resolve_enabled: raise RuntimeError( - 'UID/GID lookup is disabled within this server', + "UID/GID lookup is disabled within this server", ) user = pwd.getpwuid(self.peer_uid).pw_name # [0] @@ -1476,7 +1511,8 @@ def _close_kernel_socket(self): """Terminate the connection at the transport level.""" # Honor ``sock_shutdown`` for PyOpenSSL connections. shutdown = getattr( - self.socket, 'sock_shutdown', + self.socket, + "sock_shutdown", self.socket.shutdown, ) @@ -1492,7 +1528,7 @@ def _close_kernel_socket(self): class HTTPServer: """An HTTP server.""" - _bind_addr = '127.0.0.1' + _bind_addr = "127.0.0.1" _interrupt = None gateway = None @@ -1509,7 +1545,7 @@ class HTTPServer: server_name = None """The name of the server; defaults to ``self.version``.""" - protocol = 'HTTP/1.1' + protocol = "HTTP/1.1" """The version string to write in the Status-Line of all HTTP responses. For example, "HTTP/1.1" is the default. This also limits the supported @@ -1533,7 +1569,7 @@ class HTTPServer: expired connections (default 0.5). """ - version = 'Cheroot/{version!s}'.format(version=__version__) + version = "Cheroot/{version!s}".format(version=__version__) """A version string for the HTTPServer.""" software = None @@ -1585,9 +1621,14 @@ class HTTPServer: Default is 10. Set to None to have unlimited connections.""" def __init__( - self, bind_addr, gateway, - minthreads=10, maxthreads=-1, server_name=None, - peercreds_enabled=False, peercreds_resolve_enabled=False, + self, + bind_addr, + gateway, + minthreads=10, + maxthreads=-1, + server_name=None, + peercreds_enabled=False, + peercreds_resolve_enabled=False, reuse_port=False, ): """Initialize HTTPServer instance. @@ -1606,16 +1647,16 @@ def __init__( self.gateway = gateway self.requests = threadpool.ThreadPool( - self, min=minthreads or 1, max=maxthreads, + self, + min=minthreads or 1, + max=maxthreads, ) if not server_name: server_name = self.version self.server_name = server_name self.peercreds_enabled = peercreds_enabled - self.peercreds_resolve_enabled = ( - peercreds_resolve_enabled and peercreds_enabled - ) + self.peercreds_resolve_enabled = peercreds_resolve_enabled and peercreds_enabled self.reuse_port = reuse_port self.clear_stats() @@ -1624,43 +1665,60 @@ def clear_stats(self): self._start_time = None self._run_time = 0 self.stats = { - 'Enabled': False, - 'Bind Address': lambda s: repr(self.bind_addr), - 'Run time': lambda s: (not s['Enabled']) and -1 or self.runtime(), - 'Accepts': 0, - 'Accepts/sec': lambda s: s['Accepts'] / self.runtime(), - 'Queue': lambda s: getattr(self.requests, 'qsize', None), - 'Threads': lambda s: len(getattr(self.requests, '_threads', [])), - 'Threads Idle': lambda s: getattr(self.requests, 'idle', None), - 'Socket Errors': 0, - 'Requests': lambda s: (not s['Enabled']) and -1 or sum( - (w['Requests'](w) for w in s['Worker Threads'].values()), 0, + "Enabled": False, + "Bind Address": lambda s: repr(self.bind_addr), + "Run time": lambda s: (not s["Enabled"]) and -1 or self.runtime(), + "Accepts": 0, + "Accepts/sec": lambda s: s["Accepts"] / self.runtime(), + "Queue": lambda s: getattr(self.requests, "qsize", None), + "Threads": lambda s: len(getattr(self.requests, "_threads", [])), + "Threads Idle": lambda s: getattr(self.requests, "idle", None), + "Socket Errors": 0, + "Requests": lambda s: (not s["Enabled"]) + and -1 + or sum( + (w["Requests"](w) for w in s["Worker Threads"].values()), + 0, ), - 'Bytes Read': lambda s: (not s['Enabled']) and -1 or sum( - (w['Bytes Read'](w) for w in s['Worker Threads'].values()), 0, + "Bytes Read": lambda s: (not s["Enabled"]) + and -1 + or sum( + (w["Bytes Read"](w) for w in s["Worker Threads"].values()), + 0, ), - 'Bytes Written': lambda s: (not s['Enabled']) and -1 or sum( - (w['Bytes Written'](w) for w in s['Worker Threads'].values()), + "Bytes Written": lambda s: (not s["Enabled"]) + and -1 + or sum( + (w["Bytes Written"](w) for w in s["Worker Threads"].values()), 0, ), - 'Work Time': lambda s: (not s['Enabled']) and -1 or sum( - (w['Work Time'](w) for w in s['Worker Threads'].values()), 0, + "Work Time": lambda s: (not s["Enabled"]) + and -1 + or sum( + (w["Work Time"](w) for w in s["Worker Threads"].values()), + 0, ), - 'Read Throughput': lambda s: (not s['Enabled']) and -1 or sum( + "Read Throughput": lambda s: (not s["Enabled"]) + and -1 + or sum( ( - w['Bytes Read'](w) / (w['Work Time'](w) or 1e-6) - for w in s['Worker Threads'].values() - ), 0, + w["Bytes Read"](w) / (w["Work Time"](w) or 1e-6) + for w in s["Worker Threads"].values() + ), + 0, ), - 'Write Throughput': lambda s: (not s['Enabled']) and -1 or sum( + "Write Throughput": lambda s: (not s["Enabled"]) + and -1 + or sum( ( - w['Bytes Written'](w) / (w['Work Time'](w) or 1e-6) - for w in s['Worker Threads'].values() - ), 0, + w["Bytes Written"](w) / (w["Work Time"](w) or 1e-6) + for w in s["Worker Threads"].values() + ), + 0, ), - 'Worker Threads': {}, + "Worker Threads": {}, } - logging.statistics['Cheroot HTTPServer %d' % id(self)] = self.stats + logging.statistics["Cheroot HTTPServer %d" % id(self)] = self.stats def runtime(self): """Return server uptime.""" @@ -1671,8 +1729,9 @@ def runtime(self): def __str__(self): """Render Server instance representing bind address.""" - return '%s.%s(%r)' % ( - self.__module__, self.__class__.__name__, + return "%s.%s(%r)" % ( + self.__module__, + self.__class__.__name__, self.bind_addr, ) @@ -1707,7 +1766,7 @@ def bind_addr(self): @bind_addr.setter def bind_addr(self, value): """Set the interface on which to listen for connections.""" - if isinstance(value, tuple) and value[0] in ('', None): + if isinstance(value, tuple) and value[0] in ("", None): # Despite the socket module docs, using '' does not # allow AI_PASSIVE to work. Passing None instead # returns '0.0.0.0' like we want. In other words: @@ -1721,7 +1780,7 @@ def bind_addr(self, value): raise ValueError( "Host values of '' or None are not allowed. " "Use '0.0.0.0' (IPv4) or '::' (IPv6) instead " - 'to listen on all active interfaces.', + "to listen on all active interfaces.", ) self._bind_addr = value @@ -1749,12 +1808,12 @@ def prepare(self): # noqa: C901 # FIXME self._interrupt = None if self.software is None: - self.software = '%s Server' % self.version + self.software = "%s Server" % self.version # Select the appropriate socket self.socket = None - msg = 'No socket could be created' - if os.getenv('LISTEN_PID', None): + msg = "No socket could be created" + if os.getenv("LISTEN_PID", None): # systemd socket activation self.socket = socket.fromfd(3, socket.AF_INET, socket.SOCK_STREAM) elif isinstance(self.bind_addr, (str, bytes)): @@ -1762,7 +1821,7 @@ def prepare(self): # noqa: C901 # FIXME try: self.bind_unix_socket(self.bind_addr) except socket.error as serr: - msg = '%s -- (%s: %s)' % (msg, self.bind_addr, serr) + msg = "%s -- (%s: %s)" % (msg, self.bind_addr, serr) raise socket.error(msg) from serr else: # AF_INET or AF_INET6 socket @@ -1771,18 +1830,22 @@ def prepare(self): # noqa: C901 # FIXME host, port = self.bind_addr try: info = socket.getaddrinfo( - host, port, socket.AF_UNSPEC, - socket.SOCK_STREAM, 0, socket.AI_PASSIVE, + host, + port, + socket.AF_UNSPEC, + socket.SOCK_STREAM, + 0, + socket.AI_PASSIVE, ) except socket.gaierror: sock_type = socket.AF_INET bind_addr = self.bind_addr - if ':' in host: + if ":" in host: sock_type = socket.AF_INET6 bind_addr = bind_addr + (0, 0) - info = [(sock_type, socket.SOCK_STREAM, 0, '', bind_addr)] + info = [(sock_type, socket.SOCK_STREAM, 0, "", bind_addr)] for res in info: af, socktype, proto, _canonname, sa = res @@ -1790,7 +1853,7 @@ def prepare(self): # noqa: C901 # FIXME self.bind(af, socktype, proto) break except socket.error as serr: - msg = '%s -- (%s: %s)' % (msg, sa, serr) + msg = "%s -- (%s: %s)" % (msg, sa, serr) if self.socket: self.socket.close() self.socket = None @@ -1820,7 +1883,8 @@ def serve(self): raise except Exception: self.error_log( - 'Error in HTTPServer.serve', level=logging.ERROR, + "Error in HTTPServer.serve", + level=logging.ERROR, traceback=True, ) @@ -1869,7 +1933,7 @@ def put_conn(self, conn): # server is shutting down, just close it conn.close() - def error_log(self, msg='', level=20, traceback=False): + def error_log(self, msg="", level=20, traceback=False): """Write error message to log. Args: @@ -1878,7 +1942,7 @@ def error_log(self, msg='', level=20, traceback=False): traceback (bool): add traceback to output or not """ # Override this in subclasses as desired - sys.stderr.write('{msg!s}\n'.format(msg=msg)) + sys.stderr.write("{msg!s}\n".format(msg=msg)) sys.stderr.flush() if traceback: tblines = traceback_.format_exc() @@ -1889,8 +1953,11 @@ def bind(self, family, type, proto=0): """Create (or recreate) the actual socket object.""" sock = self.prepare_socket( self.bind_addr, - family, type, proto, - self.nodelay, self.ssl_adapter, + family, + type, + proto, + self.nodelay, + self.ssl_adapter, self.reuse_port, ) sock = self.socket = self.bind_socket(sock, self.bind_addr) @@ -1905,7 +1972,7 @@ def bind_unix_socket(self, bind_addr): # noqa: C901 # FIXME causes an AttributeError. """ raise ValueError( # or RuntimeError? - 'AF_UNIX sockets are not supported under Windows.', + "AF_UNIX sockets are not supported under Windows.", ) fs_permissions = 0o777 # TODO: allow changing mode @@ -1920,26 +1987,27 @@ def bind_unix_socket(self, bind_addr): # noqa: C901 # FIXME except TypeError as typ_err: err_msg = str(typ_err) if ( - 'remove() argument 1 must be encoded ' - 'string without null bytes, not unicode' - not in err_msg + "remove() argument 1 must be encoded " + "string without null bytes, not unicode" not in err_msg ): raise except ValueError as val_err: err_msg = str(val_err) if ( - 'unlink: embedded null ' - 'character in path' not in err_msg - and 'embedded null byte' not in err_msg - and 'argument must be a ' - 'string without NUL characters' not in err_msg # pypy3 + "unlink: embedded null " "character in path" not in err_msg + and "embedded null byte" not in err_msg + and "argument must be a " "string without NUL characters" + not in err_msg # pypy3 ): raise sock = self.prepare_socket( bind_addr=bind_addr, - family=socket.AF_UNIX, type=socket.SOCK_STREAM, proto=0, - nodelay=self.nodelay, ssl_adapter=self.ssl_adapter, + family=socket.AF_UNIX, + type=socket.SOCK_STREAM, + proto=0, + nodelay=self.nodelay, + ssl_adapter=self.ssl_adapter, reuse_port=self.reuse_port, ) @@ -1972,7 +2040,7 @@ def bind_unix_socket(self, bind_addr): # noqa: C901 # FIXME if not FS_PERMS_SET: self.error_log( - 'Failed to set socket fs mode permissions', + "Failed to set socket fs mode permissions", level=logging.WARNING, ) @@ -1986,30 +2054,36 @@ def _make_socket_reusable(socket_, bind_addr): IS_EPHEMERAL_PORT = port == 0 if socket_.family not in (socket.AF_INET, socket.AF_INET6): - raise ValueError('Cannot reuse a non-IP socket') + raise ValueError("Cannot reuse a non-IP socket") if IS_EPHEMERAL_PORT: - raise ValueError('Cannot reuse an ephemeral port (0)') + raise ValueError("Cannot reuse an ephemeral port (0)") # Most BSD kernels implement SO_REUSEPORT the way that only the # latest listener can read from socket. Some of BSD kernels also # have SO_REUSEPORT_LB that works similarly to SO_REUSEPORT # in Linux. - if hasattr(socket, 'SO_REUSEPORT_LB'): + if hasattr(socket, "SO_REUSEPORT_LB"): socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT_LB, 1) - elif hasattr(socket, 'SO_REUSEPORT'): + elif hasattr(socket, "SO_REUSEPORT"): socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) elif IS_WINDOWS: socket_.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) else: raise NotImplementedError( - 'Current platform does not support port reuse', + "Current platform does not support port reuse", ) @classmethod def prepare_socket( - cls, bind_addr, family, type, proto, nodelay, ssl_adapter, - reuse_port=False, + cls, + bind_addr, + family, + type, + proto, + nodelay, + ssl_adapter, + reuse_port=False, ): """Create and prepare the socket object.""" sock = socket.socket(family, type, proto) @@ -2043,14 +2117,16 @@ def prepare_socket( # activate dual-stack. See # https://github.com/cherrypy/cherrypy/issues/871. listening_ipv6 = ( - hasattr(socket, 'AF_INET6') + hasattr(socket, "AF_INET6") and family == socket.AF_INET6 - and host in ('::', '::0', '::0.0.0.0') + and host in ("::", "::0", "::0.0.0.0") ) if listening_ipv6: try: sock.setsockopt( - socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0, + socket.IPPROTO_IPV6, + socket.IPV6_V6ONLY, + 0, ) except (AttributeError, socket.error): # Apparently, the socket option is not available in @@ -2117,10 +2193,10 @@ def interrupt(self, interrupt): self._interrupt = _STOPPING_FOR_INTERRUPT if isinstance(interrupt, KeyboardInterrupt): - self.error_log('Keyboard Interrupt: shutting down') + self.error_log("Keyboard Interrupt: shutting down") if isinstance(interrupt, SystemExit): - self.error_log('SystemExit raised: shutting down') + self.error_log("SystemExit raised: shutting down") self.stop() self._interrupt = interrupt @@ -2132,12 +2208,12 @@ def stop(self): # noqa: C901 # FIXME self.ready = False if self._start_time is not None: - self._run_time += (time.time() - self._start_time) + self._run_time += time.time() - self._start_time self._start_time = None self._connections.stop() - sock = getattr(self, 'socket', None) + sock = getattr(self, "socket", None) if sock: if not isinstance(self.bind_addr, (str, bytes)): # Touch our own socket to make accept() return immediately. @@ -2155,7 +2231,9 @@ def stop(self): # noqa: C901 # FIXME # localhost won't work if we've bound to a public IP, # but it will if we bound to '0.0.0.0' (INADDR_ANY). for res in socket.getaddrinfo( - host, port, socket.AF_UNSPEC, + host, + port, + socket.AF_UNSPEC, socket.SOCK_STREAM, ): af, socktype, proto, _canonname, _sa = res @@ -2171,7 +2249,7 @@ def stop(self): # noqa: C901 # FIXME except socket.error: if s: s.close() - if hasattr(sock, 'close'): + if hasattr(sock, "close"): sock.close() self.socket = None @@ -2198,17 +2276,17 @@ def respond(self): # These may either be ssl.Adapter subclasses or the string names # of such classes (in which case they will be lazily loaded). ssl_adapters = { - 'builtin': 'cheroot.ssl.builtin.BuiltinSSLAdapter', - 'pyopenssl': 'cheroot.ssl.pyopenssl.pyOpenSSLAdapter', + "builtin": "cheroot.ssl.builtin.BuiltinSSLAdapter", + "pyopenssl": "cheroot.ssl.pyopenssl.pyOpenSSLAdapter", } -def get_ssl_adapter_class(name='builtin'): +def get_ssl_adapter_class(name="builtin"): """Return an SSL adapter class for the given name.""" adapter = ssl_adapters[name.lower()] if isinstance(adapter, str): - last_dot = adapter.rfind('.') - attr_name = adapter[last_dot + 1:] + last_dot = adapter.rfind(".") + attr_name = adapter[last_dot + 1 :] mod_path = adapter[:last_dot] try: @@ -2217,15 +2295,14 @@ def get_ssl_adapter_class(name='builtin'): raise KeyError() except KeyError: # The last [''] is important. - mod = __import__(mod_path, globals(), locals(), ['']) + mod = __import__(mod_path, globals(), locals(), [""]) # Let an AttributeError propagate outward. try: adapter = getattr(mod, attr_name) except AttributeError: raise AttributeError( - "'%s' object has no attribute '%s'" - % (mod_path, attr_name), + "'%s' object has no attribute '%s'" % (mod_path, attr_name), ) return adapter diff --git a/cheroot/server.pyi b/cheroot/server.pyi index ecbe2f2758..200af7b168 100644 --- a/cheroot/server.pyi +++ b/cheroot/server.pyi @@ -61,7 +61,9 @@ class HTTPRequest: chunked_read: bool proxy_mode: Any strict_mode: Any - def __init__(self, server, conn, proxy_mode: bool = ..., strict_mode: bool = ...) -> None: ... + def __init__( + self, server, conn, proxy_mode: bool = ..., strict_mode: bool = ... + ) -> None: ... rfile: Any def parse_request(self) -> None: ... uri: Any @@ -133,7 +135,17 @@ class HTTPServer: reuse_port: bool keep_alive_conn_limit: int requests: Any - def __init__(self, bind_addr, gateway, minthreads: int = ..., maxthreads: int = ..., server_name: Any | None = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ..., reuse_port: bool = ...) -> None: ... + def __init__( + self, + bind_addr, + gateway, + minthreads: int = ..., + maxthreads: int = ..., + server_name: Any | None = ..., + peercreds_enabled: bool = ..., + peercreds_resolve_enabled: bool = ..., + reuse_port: bool = ..., + ) -> None: ... stats: Any def clear_stats(self): ... def runtime(self): ... @@ -149,13 +161,24 @@ class HTTPServer: @property def can_add_keepalive_connection(self): ... def put_conn(self, conn) -> None: ... - def error_log(self, msg: str = ..., level: int = ..., traceback: bool = ...) -> None: ... + def error_log( + self, msg: str = ..., level: int = ..., traceback: bool = ... + ) -> None: ... def bind(self, family, type, proto: int = ...): ... def bind_unix_socket(self, bind_addr): ... @staticmethod def _make_socket_reusable(socket_, bind_addr) -> None: ... @classmethod - def prepare_socket(cls, bind_addr, family, type, proto, nodelay, ssl_adapter, reuse_port: bool = ...): ... + def prepare_socket( + cls, + bind_addr, + family, + type, + proto, + nodelay, + ssl_adapter, + reuse_port: bool = ..., + ): ... @staticmethod def bind_socket(socket_, bind_addr): ... @staticmethod diff --git a/cheroot/ssl/__init__.py b/cheroot/ssl/__init__.py index 19b587d0be..baae9ec556 100644 --- a/cheroot/ssl/__init__.py +++ b/cheroot/ssl/__init__.py @@ -15,8 +15,11 @@ class Adapter(metaclass=ABCMeta): @abstractmethod def __init__( - self, certificate, private_key, certificate_chain=None, - ciphers=None, + self, + certificate, + private_key, + certificate_chain=None, + ciphers=None, ): """Set up certificates, private key ciphers and reset context.""" self.certificate = certificate @@ -41,6 +44,6 @@ def get_environ(self): raise NotImplementedError # pragma: no cover @abstractmethod - def makefile(self, sock, mode='r', bufsize=-1): + def makefile(self, sock, mode="r", bufsize=-1): """Return socket file object.""" raise NotImplementedError # pragma: no cover diff --git a/cheroot/ssl/__init__.pyi b/cheroot/ssl/__init__.pyi index 4801fbdd6c..0274e9ebf6 100644 --- a/cheroot/ssl/__init__.pyi +++ b/cheroot/ssl/__init__.pyi @@ -8,7 +8,13 @@ class Adapter(metaclass=ABCMeta): ciphers: Any context: Any @abstractmethod - def __init__(self, certificate, private_key, certificate_chain: Any | None = ..., ciphers: Any | None = ...): ... + def __init__( + self, + certificate, + private_key, + certificate_chain: Any | None = ..., + ciphers: Any | None = ..., + ): ... @abstractmethod def bind(self, sock): ... @abstractmethod diff --git a/cheroot/ssl/builtin.py b/cheroot/ssl/builtin.py index e28e5df188..1774e93dd3 100644 --- a/cheroot/ssl/builtin.py +++ b/cheroot/ssl/builtin.py @@ -35,8 +35,7 @@ def _assert_ssl_exc_contains(exc, *msgs): """Check whether SSL exception contains either of messages provided.""" if len(msgs) < 1: raise TypeError( - '_assert_ssl_exc_contains() requires ' - 'at least one message to be passed.', + "_assert_ssl_exc_contains() requires " "at least one message to be passed.", ) err_msg_lower = str(exc).lower() return any(m.lower() in err_msg_lower for m in msgs) @@ -50,7 +49,9 @@ def _loopback_for_cert_thread(context, server): # https://github.com/cherrypy/cheroot/issues/302#issuecomment-662592030 with suppress(ssl.SSLError, OSError): with context.wrap_socket( - server, do_handshake_on_connect=True, server_side=True, + server, + do_handshake_on_connect=True, + server_side=True, ) as ssl_sock: # in TLS 1.3 (Python 3.7+, OpenSSL 1.1.1+), the server # sends the client session tickets that can be used to @@ -70,7 +71,7 @@ def _loopback_for_cert_thread(context, server): # tickets and close the connection cleanly. # Note that, as this is essentially a race condition, # the error may still occur ocasionally. - ssl_sock.send(b'0000') + ssl_sock.send(b"0000") def _loopback_for_cert(certificate, private_key, certificate_chain): @@ -90,13 +91,15 @@ def _loopback_for_cert(certificate, private_key, certificate_chain): # when `close` is called, the SSL shutdown notice will be sent # and then python will wait to receive the corollary shutdown. thread = threading.Thread( - target=_loopback_for_cert_thread, args=(context, server), + target=_loopback_for_cert_thread, + args=(context, server), ) try: thread.start() with context.wrap_socket( - client, do_handshake_on_connect=True, - server_side=False, + client, + do_handshake_on_connect=True, + server_side=False, ) as ssl_sock: ssl_sock.recv(4) return ssl_sock.getpeercert() @@ -150,13 +153,13 @@ class BuiltinSSLAdapter(Adapter): # from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert CERT_KEY_TO_ENV = { - 'version': 'M_VERSION', - 'serialNumber': 'M_SERIAL', - 'notBefore': 'V_START', - 'notAfter': 'V_END', - 'subject': 'S_DN', - 'issuer': 'I_DN', - 'subjectAltName': 'SAN', + "version": "M_VERSION", + "serialNumber": "M_SERIAL", + "notBefore": "V_START", + "notAfter": "V_END", + "subject": "S_DN", + "issuer": "I_DN", + "subjectAltName": "SAN", # not parsed by the Python standard library # - A_SIG # - A_KEY @@ -168,21 +171,21 @@ class BuiltinSSLAdapter(Adapter): # from mod_ssl/pkg.sslmod/ssl_engine_vars.c ssl_var_lookup_ssl_cert_dn_rec CERT_KEY_TO_LDAP_CODE = { - 'countryName': 'C', - 'stateOrProvinceName': 'ST', + "countryName": "C", + "stateOrProvinceName": "ST", # NOTE: mod_ssl also provides 'stateOrProvinceName' as 'SP' # for compatibility with SSLeay - 'localityName': 'L', - 'organizationName': 'O', - 'organizationalUnitName': 'OU', - 'commonName': 'CN', - 'title': 'T', - 'initials': 'I', - 'givenName': 'G', - 'surname': 'S', - 'description': 'D', - 'userid': 'UID', - 'emailAddress': 'Email', + "localityName": "L", + "organizationName": "O", + "organizationalUnitName": "OU", + "commonName": "CN", + "title": "T", + "initials": "I", + "givenName": "G", + "surname": "S", + "description": "D", + "userid": "UID", + "emailAddress": "Email", # not provided by mod_ssl # - dnQualifier: DNQ # - domainComponent: DC @@ -198,15 +201,21 @@ class BuiltinSSLAdapter(Adapter): } def __init__( - self, certificate, private_key, certificate_chain=None, - ciphers=None, + self, + certificate, + private_key, + certificate_chain=None, + ciphers=None, ): """Set up context in addition to base class properties if available.""" if ssl is None: - raise ImportError('You must install the ssl module to use HTTPS.') + raise ImportError("You must install the ssl module to use HTTPS.") super(BuiltinSSLAdapter, self).__init__( - certificate, private_key, certificate_chain, ciphers, + certificate, + private_key, + certificate_chain, + ciphers, ) self.context = ssl.create_default_context( @@ -218,13 +227,13 @@ def __init__( self.context.set_ciphers(ciphers) self._server_env = self._make_env_cert_dict( - 'SSL_SERVER', + "SSL_SERVER", _parse_cert(certificate, private_key, self.certificate_chain), ) if not self._server_env: return cert = None - with open(certificate, mode='rt') as f: + with open(certificate, mode="rt") as f: cert = f.read() # strip off any keys by only taking the first certificate @@ -235,7 +244,7 @@ def __init__( if cert_end == -1: return cert_end += len(ssl.PEM_FOOTER) - self._server_env['SSL_SERVER_CERT'] = cert[cert_start:cert_end] + self._server_env["SSL_SERVER_CERT"] = cert[cert_start:cert_end] @property def context(self): @@ -264,7 +273,9 @@ def wrap(self, sock): """Wrap and return the given socket, plus WSGI environ entries.""" try: s = self.context.wrap_socket( - sock, do_handshake_on_connect=True, server_side=True, + sock, + do_handshake_on_connect=True, + server_side=True, ) except ( ssl.SSLEOFError, @@ -275,8 +286,8 @@ def wrap(self, sock): ) from tls_connection_drop_error except ssl.SSLError as generic_tls_error: peer_speaks_plain_http_over_https = ( - generic_tls_error.errno == ssl.SSL_ERROR_SSL and - _assert_ssl_exc_contains(generic_tls_error, 'http request') + generic_tls_error.errno == ssl.SSL_ERROR_SSL + and _assert_ssl_exc_contains(generic_tls_error, "http request") ) if peer_speaks_plain_http_over_https: reraised_connection_drop_exc_cls = errors.NoSSLError @@ -297,17 +308,19 @@ def get_environ(self, sock): """Create WSGI environ entries to be merged into each request.""" cipher = sock.cipher() ssl_environ = { - 'wsgi.url_scheme': 'https', - 'HTTPS': 'on', - 'SSL_PROTOCOL': cipher[1], - 'SSL_CIPHER': cipher[0], - 'SSL_CIPHER_EXPORT': '', - 'SSL_CIPHER_USEKEYSIZE': cipher[2], - 'SSL_VERSION_INTERFACE': '%s Python/%s' % ( - HTTPServer.version, sys.version, + "wsgi.url_scheme": "https", + "HTTPS": "on", + "SSL_PROTOCOL": cipher[1], + "SSL_CIPHER": cipher[0], + "SSL_CIPHER_EXPORT": "", + "SSL_CIPHER_USEKEYSIZE": cipher[2], + "SSL_VERSION_INTERFACE": "%s Python/%s" + % ( + HTTPServer.version, + sys.version, ), - 'SSL_VERSION_LIBRARY': ssl.OPENSSL_VERSION, - 'SSL_CLIENT_VERIFY': 'NONE', + "SSL_VERSION_LIBRARY": ssl.OPENSSL_VERSION, + "SSL_CLIENT_VERIFY": "NONE", # 'NONE' - client did not provide a cert (overriden below) } @@ -315,32 +328,32 @@ def get_environ(self, sock): with suppress(AttributeError): compression = sock.compression() if compression is not None: - ssl_environ['SSL_COMPRESS_METHOD'] = compression + ssl_environ["SSL_COMPRESS_METHOD"] = compression # Python 3.6+ with suppress(AttributeError): - ssl_environ['SSL_SESSION_ID'] = sock.session.id.hex() + ssl_environ["SSL_SESSION_ID"] = sock.session.id.hex() with suppress(AttributeError): target_cipher = cipher[:2] for cip in sock.context.get_ciphers(): - if target_cipher == (cip['name'], cip['protocol']): - ssl_environ['SSL_CIPHER_ALGKEYSIZE'] = cip['alg_bits'] + if target_cipher == (cip["name"], cip["protocol"]): + ssl_environ["SSL_CIPHER_ALGKEYSIZE"] = cip["alg_bits"] break # Python 3.7+ sni_callback with suppress(AttributeError): - ssl_environ['SSL_TLS_SNI'] = sock.sni + ssl_environ["SSL_TLS_SNI"] = sock.sni if self.context and self.context.verify_mode != ssl.CERT_NONE: client_cert = sock.getpeercert() if client_cert: # builtin ssl **ALWAYS** validates client certificates # and terminates the connection on failure - ssl_environ['SSL_CLIENT_VERIFY'] = 'SUCCESS' + ssl_environ["SSL_CLIENT_VERIFY"] = "SUCCESS" ssl_environ.update( - self._make_env_cert_dict('SSL_CLIENT', client_cert), + self._make_env_cert_dict("SSL_CLIENT", client_cert), ) - ssl_environ['SSL_CLIENT_CERT'] = ssl.DER_cert_to_PEM_cert( + ssl_environ["SSL_CLIENT_CERT"] = ssl.DER_cert_to_PEM_cert( sock.getpeercert(binary_form=True), ).strip() @@ -366,22 +379,22 @@ def _make_env_cert_dict(self, env_prefix, parsed_cert): env = {} for cert_key, env_var in self.CERT_KEY_TO_ENV.items(): - key = '%s_%s' % (env_prefix, env_var) + key = "%s_%s" % (env_prefix, env_var) value = parsed_cert.get(cert_key) - if env_var == 'SAN': + if env_var == "SAN": env.update(self._make_env_san_dict(key, value)) - elif env_var.endswith('_DN'): + elif env_var.endswith("_DN"): env.update(self._make_env_dn_dict(key, value)) else: env[key] = str(value) # mod_ssl 2.1+; Python 3.2+ # number of days until the certificate expires - if 'notBefore' in parsed_cert: - remain = ssl.cert_time_to_seconds(parsed_cert['notAfter']) - remain -= ssl.cert_time_to_seconds(parsed_cert['notBefore']) + if "notBefore" in parsed_cert: + remain = ssl.cert_time_to_seconds(parsed_cert["notAfter"]) + remain -= ssl.cert_time_to_seconds(parsed_cert["notBefore"]) remain /= 60 * 60 * 24 - env['%s_V_REMAIN' % (env_prefix,)] = str(int(remain)) + env["%s_V_REMAIN" % (env_prefix,)] = str(int(remain)) return env @@ -399,11 +412,11 @@ def _make_env_san_dict(self, env_prefix, cert_value): dns_count = 0 email_count = 0 for attr_name, val in cert_value: - if attr_name == 'DNS': - env['%s_DNS_%i' % (env_prefix, dns_count)] = val + if attr_name == "DNS": + env["%s_DNS_%i" % (env_prefix, dns_count)] = val dns_count += 1 - elif attr_name == 'Email': - env['%s_Email_%i' % (env_prefix, email_count)] = val + elif attr_name == "Email": + env["%s_Email_%i" % (env_prefix, email_count)] = val email_count += 1 # other mod_ssl SAN vars: @@ -425,24 +438,24 @@ def _make_env_dn_dict(self, env_prefix, cert_value): for rdn in cert_value: for attr_name, val in rdn: attr_code = self.CERT_KEY_TO_LDAP_CODE.get(attr_name) - dn.append('%s=%s' % (attr_code or attr_name, val)) + dn.append("%s=%s" % (attr_code or attr_name, val)) if not attr_code: continue dn_attrs.setdefault(attr_code, []) dn_attrs[attr_code].append(val) env = { - env_prefix: ','.join(dn), + env_prefix: ",".join(dn), } for attr_code, values in dn_attrs.items(): - env['%s_%s' % (env_prefix, attr_code)] = ','.join(values) + env["%s_%s" % (env_prefix, attr_code)] = ",".join(values) if len(values) == 1: continue for i, val in enumerate(values): - env['%s_%s_%i' % (env_prefix, attr_code, i)] = val + env["%s_%s_%i" % (env_prefix, attr_code, i)] = val return env - def makefile(self, sock, mode='r', bufsize=DEFAULT_BUFFER_SIZE): + def makefile(self, sock, mode="r", bufsize=DEFAULT_BUFFER_SIZE): """Return socket file object.""" - cls = StreamReader if 'r' in mode else StreamWriter + cls = StreamReader if "r" in mode else StreamWriter return cls(sock, mode, bufsize) diff --git a/cheroot/ssl/builtin.pyi b/cheroot/ssl/builtin.pyi index 72e4500179..7c3821f901 100644 --- a/cheroot/ssl/builtin.pyi +++ b/cheroot/ssl/builtin.pyi @@ -6,7 +6,13 @@ DEFAULT_BUFFER_SIZE: int class BuiltinSSLAdapter(Adapter): CERT_KEY_TO_ENV: Any CERT_KEY_TO_LDAP_CODE: Any - def __init__(self, certificate, private_key, certificate_chain: Any | None = ..., ciphers: Any | None = ...) -> None: ... + def __init__( + self, + certificate, + private_key, + certificate_chain: Any | None = ..., + ciphers: Any | None = ..., + ) -> None: ... @property def context(self): ... @context.setter diff --git a/cheroot/ssl/pyopenssl.py b/cheroot/ssl/pyopenssl.py index 8b01b348de..6182331471 100644 --- a/cheroot/ssl/pyopenssl.py +++ b/cheroot/ssl/pyopenssl.py @@ -76,7 +76,7 @@ class SSLFileobjectMixin: """Base mixin for a TLS socket stream.""" ssl_timeout = 3 - ssl_retry = .01 + ssl_retry = 0.01 # FIXME: def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901 @@ -99,16 +99,16 @@ def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901 except SSL.WantWriteError: time.sleep(self.ssl_retry) except SSL.SysCallError as e: - if is_reader and e.args == (-1, 'Unexpected EOF'): - return b'' + if is_reader and e.args == (-1, "Unexpected EOF"): + return b"" errnum = e.args[0] if is_reader and errnum in errors.socket_errors_to_ignore: - return b'' + return b"" raise socket.error(errnum) except SSL.Error as e: - if is_reader and e.args == (-1, 'Unexpected EOF'): - return b'' + if is_reader and e.args == (-1, "Unexpected EOF"): + return b"" thirdarg = None try: @@ -116,14 +116,14 @@ def _safe_call(self, is_reader, call, *args, **kwargs): # noqa: C901 except IndexError: pass - if thirdarg == 'http request': + if thirdarg == "http request": # The client is talking HTTP to an HTTPS server. raise errors.NoSSLError() raise errors.FatalSSLAlert(*e.args) if time.time() - start > self.ssl_timeout: - raise socket.timeout('timed out') + raise socket.timeout("timed out") def recv(self, size): """Receive message of a size from the socket.""" @@ -150,7 +150,8 @@ def sendall(self, *args, **kwargs): return self._safe_call( False, super(SSLFileobjectMixin, self).sendall, - *args, **kwargs, + *args, + **kwargs, ) def send(self, *args, **kwargs): @@ -158,7 +159,8 @@ def send(self, *args, **kwargs): return self._safe_call( False, super(SSLFileobjectMixin, self).send, - *args, **kwargs, + *args, + **kwargs, ) @@ -176,23 +178,44 @@ class SSLConnectionProxyMeta: def __new__(mcl, name, bases, nmspc): """Attach a list of proxy methods to a new class.""" proxy_methods = ( - 'get_context', 'pending', 'send', 'write', 'recv', 'read', - 'renegotiate', 'bind', 'listen', 'connect', 'accept', - 'setblocking', 'fileno', 'close', 'get_cipher_list', - 'getpeername', 'getsockname', 'getsockopt', 'setsockopt', - 'makefile', 'get_app_data', 'set_app_data', 'state_string', - 'sock_shutdown', 'get_peer_certificate', 'want_read', - 'want_write', 'set_connect_state', 'set_accept_state', - 'connect_ex', 'sendall', 'settimeout', 'gettimeout', - 'shutdown', - ) - proxy_methods_no_args = ( - 'shutdown', + "get_context", + "pending", + "send", + "write", + "recv", + "read", + "renegotiate", + "bind", + "listen", + "connect", + "accept", + "setblocking", + "fileno", + "close", + "get_cipher_list", + "getpeername", + "getsockname", + "getsockopt", + "setsockopt", + "makefile", + "get_app_data", + "set_app_data", + "state_string", + "sock_shutdown", + "get_peer_certificate", + "want_read", + "want_write", + "set_connect_state", + "set_accept_state", + "connect_ex", + "sendall", + "settimeout", + "gettimeout", + "shutdown", ) + proxy_methods_no_args = ("shutdown",) - proxy_props = ( - 'family', - ) + proxy_props = ("family",) def lock_decorator(method): """Create a proxy method for a new class.""" @@ -200,13 +223,13 @@ def lock_decorator(method): def proxy_wrapper(self, *args): self._lock.acquire() try: - new_args = ( - args[:] if method not in proxy_methods_no_args else [] - ) + new_args = args[:] if method not in proxy_methods_no_args else [] return getattr(self._ssl_conn, method)(*new_args) finally: self._lock.release() + return proxy_wrapper + for m in proxy_methods: nmspc[m] = lock_decorator(m) nmspc[m].__name__ = m @@ -216,8 +239,10 @@ def make_property(property_): def proxy_prop_wrapper(self): return getattr(self._ssl_conn, property_) + proxy_prop_wrapper.__name__ = property_ return property(proxy_prop_wrapper) + for p in proxy_props: nmspc[p] = make_property(p) @@ -264,15 +289,21 @@ class pyOpenSSLAdapter(Adapter): """The ciphers list of TLS.""" def __init__( - self, certificate, private_key, certificate_chain=None, - ciphers=None, + self, + certificate, + private_key, + certificate_chain=None, + ciphers=None, ): """Initialize OpenSSL Adapter instance.""" if SSL is None: - raise ImportError('You must install pyOpenSSL to use HTTPS.') + raise ImportError("You must install pyOpenSSL to use HTTPS.") super(pyOpenSSLAdapter, self).__init__( - certificate, private_key, certificate_chain, ciphers, + certificate, + private_key, + certificate_chain, + ciphers, ) self._environ = None @@ -308,66 +339,67 @@ def get_context(self): def get_environ(self): """Return WSGI environ entries to be merged into each request.""" ssl_environ = { - 'wsgi.url_scheme': 'https', - 'HTTPS': 'on', - 'SSL_VERSION_INTERFACE': '%s %s/%s Python/%s' % ( + "wsgi.url_scheme": "https", + "HTTPS": "on", + "SSL_VERSION_INTERFACE": "%s %s/%s Python/%s" + % ( cheroot_server.HTTPServer.version, - OpenSSL.version.__title__, OpenSSL.version.__version__, + OpenSSL.version.__title__, + OpenSSL.version.__version__, sys.version, ), - 'SSL_VERSION_LIBRARY': SSL.SSLeay_version( + "SSL_VERSION_LIBRARY": SSL.SSLeay_version( SSL.SSLEAY_VERSION, ).decode(), } if self.certificate: # Server certificate attributes - with open(self.certificate, 'rb') as cert_file: + with open(self.certificate, "rb") as cert_file: cert = crypto.load_certificate( - crypto.FILETYPE_PEM, cert_file.read(), + crypto.FILETYPE_PEM, + cert_file.read(), ) - ssl_environ.update({ - 'SSL_SERVER_M_VERSION': cert.get_version(), - 'SSL_SERVER_M_SERIAL': cert.get_serial_number(), - # 'SSL_SERVER_V_START': - # Validity of server's certificate (start time), - # 'SSL_SERVER_V_END': - # Validity of server's certificate (end time), - }) + ssl_environ.update( + { + "SSL_SERVER_M_VERSION": cert.get_version(), + "SSL_SERVER_M_SERIAL": cert.get_serial_number(), + # 'SSL_SERVER_V_START': + # Validity of server's certificate (start time), + # 'SSL_SERVER_V_END': + # Validity of server's certificate (end time), + } + ) for prefix, dn in [ - ('I', cert.get_issuer()), - ('S', cert.get_subject()), + ("I", cert.get_issuer()), + ("S", cert.get_subject()), ]: # X509Name objects don't seem to have a way to get the # complete DN string. Use str() and slice it instead, # because str(dn) == "" dnstr = str(dn)[18:-2] - wsgikey = 'SSL_SERVER_%s_DN' % prefix + wsgikey = "SSL_SERVER_%s_DN" % prefix ssl_environ[wsgikey] = dnstr # The DN should be of the form: /k1=v1/k2=v2, but we must allow # for any value to contain slashes itself (in a URL). while dnstr: - pos = dnstr.rfind('=') - dnstr, value = dnstr[:pos], dnstr[pos + 1:] - pos = dnstr.rfind('/') - dnstr, key = dnstr[:pos], dnstr[pos + 1:] + pos = dnstr.rfind("=") + dnstr, value = dnstr[:pos], dnstr[pos + 1 :] + pos = dnstr.rfind("/") + dnstr, key = dnstr[:pos], dnstr[pos + 1 :] if key and value: - wsgikey = 'SSL_SERVER_%s_DN_%s' % (prefix, key) + wsgikey = "SSL_SERVER_%s_DN_%s" % (prefix, key) ssl_environ[wsgikey] = value return ssl_environ - def makefile(self, sock, mode='r', bufsize=-1): + def makefile(self, sock, mode="r", bufsize=-1): """Return socket file object.""" - cls = ( - SSLFileobjectStreamReader - if 'r' in mode else - SSLFileobjectStreamWriter - ) + cls = SSLFileobjectStreamReader if "r" in mode else SSLFileobjectStreamWriter if SSL and isinstance(sock, ssl_conn_type): wrapped_socket = cls(sock, mode, bufsize) wrapped_socket.ssl_timeout = sock.gettimeout() diff --git a/cheroot/ssl/pyopenssl.pyi b/cheroot/ssl/pyopenssl.pyi index 107675c9bc..19c8d0f575 100644 --- a/cheroot/ssl/pyopenssl.pyi +++ b/cheroot/ssl/pyopenssl.pyi @@ -23,7 +23,13 @@ class SSLConnection: def __init__(self, *args) -> None: ... class pyOpenSSLAdapter(Adapter): - def __init__(self, certificate, private_key, certificate_chain: Any | None = ..., ciphers: Any | None = ...) -> None: ... + def __init__( + self, + certificate, + private_key, + certificate_chain: Any | None = ..., + ciphers: Any | None = ..., + ) -> None: ... def bind(self, sock): ... def wrap(self, sock): ... def get_environ(self): ... diff --git a/cheroot/test/_pytest_plugin.py b/cheroot/test/_pytest_plugin.py index 61f2efe126..9b6d2c541e 100644 --- a/cheroot/test/_pytest_plugin.py +++ b/cheroot/test/_pytest_plugin.py @@ -7,7 +7,7 @@ import pytest -pytest_version = tuple(map(int, pytest.__version__.split('.'))) +pytest_version = tuple(map(int, pytest.__version__.split("."))) def pytest_load_initial_conftests(early_config, parser, args): @@ -19,25 +19,27 @@ def pytest_load_initial_conftests(early_config, parser, args): # Refs: # * https://docs.pytest.org/en/stable/usage.html#unraisable # * https://github.com/pytest-dev/pytest/issues/5299 - early_config._inicache['filterwarnings'].extend(( - 'ignore:Exception in thread CP Server Thread-:' - 'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception', - 'ignore:Exception in thread Thread-:' - 'pytest.PytestUnhandledThreadExceptionWarning:_pytest.threadexception', - 'ignore:Exception ignored in. ' - '` happens. """ exc_instance = ( - None if simulated_exception is None - else simulated_exception(error_number, 'Simulated socket error') + None + if simulated_exception is None + else simulated_exception(error_number, "Simulated socket error") ) old_close_kernel_socket = ( - test_client.server_instance. - ConnectionClass._close_kernel_socket + test_client.server_instance.ConnectionClass._close_kernel_socket ) def _close_kernel_socket(self): monkeypatch.setattr( # `socket.shutdown` is read-only otherwise - self, 'socket', + self, + "socket", mocker.mock_module.Mock(wraps=self.socket), ) if exc_instance is not None: monkeypatch.setattr( - self.socket, 'shutdown', + self.socket, + "shutdown", mocker.mock_module.Mock(side_effect=exc_instance), ) - _close_kernel_socket.fin_spy = mocker.spy(self.socket, 'shutdown') + _close_kernel_socket.fin_spy = mocker.spy(self.socket, "shutdown") try: old_close_kernel_socket(self) @@ -673,28 +705,29 @@ def _close_kernel_socket(self): monkeypatch.setattr( test_client.server_instance.ConnectionClass, - '_close_kernel_socket', + "_close_kernel_socket", _close_kernel_socket, ) conn = test_client.get_connection() conn.auto_open = False conn.connect() - conn.send(b'GET /hello HTTP/1.1') - conn.send(('Host: %s' % conn.host).encode('ascii')) + conn.send(b"GET /hello HTTP/1.1") + conn.send(("Host: %s" % conn.host).encode("ascii")) conn.close() # Let the server attempt TCP shutdown: for _ in range(10 * (2 if IS_SLOW_ENV else 1)): time.sleep(0.1) - if hasattr(_close_kernel_socket, 'exception_leaked'): + if hasattr(_close_kernel_socket, "exception_leaked"): break if exc_instance is not None: # simulated by us assert _close_kernel_socket.fin_spy.spy_exception is exc_instance else: # real assert isinstance( - _close_kernel_socket.fin_spy.spy_exception, socket.error, + _close_kernel_socket.fin_spy.spy_exception, + socket.error, ) assert _close_kernel_socket.fin_spy.spy_exception.errno == error_number @@ -702,75 +735,78 @@ def _close_kernel_socket(self): def test_broken_connection_during_http_communication_fallback( # noqa: WPS118 - monkeypatch, - test_client, - testing_server, - wsgi_server_thread, + monkeypatch, + test_client, + testing_server, + wsgi_server_thread, ): """Test that unhandled internal error cascades into shutdown.""" + def _raise_connection_reset(*_args, **_kwargs): raise ConnectionResetError(666) def _read_request_line(self): - monkeypatch.setattr(self.conn.rfile, 'close', _raise_connection_reset) - monkeypatch.setattr(self.conn.wfile, 'write', _raise_connection_reset) + monkeypatch.setattr(self.conn.rfile, "close", _raise_connection_reset) + monkeypatch.setattr(self.conn.wfile, "write", _raise_connection_reset) _raise_connection_reset() monkeypatch.setattr( test_client.server_instance.ConnectionClass.RequestHandlerClass, - 'read_request_line', + "read_request_line", _read_request_line, ) - test_client.get_connection().send(b'GET / HTTP/1.1') + test_client.get_connection().send(b"GET / HTTP/1.1") wsgi_server_thread.join() # no extra logs upon server termination actual_log_entries = testing_server.error_log.calls[:] testing_server.error_log.calls.clear() # prevent post-test assertions expected_log_entries = ( - (logging.WARNING, r'^socket\.error 666$'), + (logging.WARNING, r"^socket\.error 666$"), ( logging.INFO, - '^Got a connection error while handling a connection ' - r'from .*:\d{1,5} \(666\)', + "^Got a connection error while handling a connection " + r"from .*:\d{1,5} \(666\)", ), ( logging.CRITICAL, - r'A fatal exception happened\. Setting the server interrupt flag ' - r'to ConnectionResetError\(666,?\) and giving up\.\n\nPlease, ' - 'report this on the Cheroot tracker at ' - r', ' - 'providing a full reproducer with as much context and details ' - r'as possible\.$', + r"A fatal exception happened\. Setting the server interrupt flag " + r"to ConnectionResetError\(666,?\) and giving up\.\n\nPlease, " + "report this on the Cheroot tracker at " + r", " + "providing a full reproducer with as much context and details " + r"as possible\.$", ), ) assert len(actual_log_entries) == len(expected_log_entries) for ( # noqa: WPS352 - (expected_log_level, expected_msg_regex), - (actual_msg, actual_log_level, _tb), + (expected_log_level, expected_msg_regex), + (actual_msg, actual_log_level, _tb), ) in zip(expected_log_entries, actual_log_entries): assert expected_log_level == actual_log_level - assert _matches_pattern(expected_msg_regex, actual_msg) is not None, ( - f'{actual_msg !r} does not match {expected_msg_regex !r}' - ) + assert ( + _matches_pattern(expected_msg_regex, actual_msg) is not None + ), f"{actual_msg !r} does not match {expected_msg_regex !r}" def test_kb_int_from_http_handler( - test_client, - testing_server, - wsgi_server_thread, + test_client, + testing_server, + wsgi_server_thread, ): """Test that a keyboard interrupt from HTTP handler causes shutdown.""" + def _trigger_kb_intr(_req, _resp): - raise KeyboardInterrupt('simulated test handler keyboard interrupt') - testing_server.wsgi_app.handlers['/kb_intr'] = _trigger_kb_intr + raise KeyboardInterrupt("simulated test handler keyboard interrupt") + + testing_server.wsgi_app.handlers["/kb_intr"] = _trigger_kb_intr http_conn = test_client.get_connection() - http_conn.putrequest('GET', '/kb_intr', skip_host=True) - http_conn.putheader('Host', http_conn.host) + http_conn.putrequest("GET", "/kb_intr", skip_host=True) + http_conn.putheader("Host", http_conn.host) http_conn.endheaders() wsgi_server_thread.join() # no extra logs upon server termination @@ -780,38 +816,38 @@ def _trigger_kb_intr(_req, _resp): expected_log_entries = ( ( logging.DEBUG, - '^Got a server shutdown request while handling a connection ' - r'from .*:\d{1,5} \(simulated test handler keyboard interrupt\)$', + "^Got a server shutdown request while handling a connection " + r"from .*:\d{1,5} \(simulated test handler keyboard interrupt\)$", ), ( logging.DEBUG, - '^Setting the server interrupt flag to KeyboardInterrupt' + "^Setting the server interrupt flag to KeyboardInterrupt" r"\('simulated test handler keyboard interrupt',?\)$", ), ( logging.INFO, - '^Keyboard Interrupt: shutting down$', + "^Keyboard Interrupt: shutting down$", ), ) assert len(actual_log_entries) == len(expected_log_entries) for ( # noqa: WPS352 - (expected_log_level, expected_msg_regex), - (actual_msg, actual_log_level, _tb), + (expected_log_level, expected_msg_regex), + (actual_msg, actual_log_level, _tb), ) in zip(expected_log_entries, actual_log_entries): assert expected_log_level == actual_log_level - assert _matches_pattern(expected_msg_regex, actual_msg) is not None, ( - f'{actual_msg !r} does not match {expected_msg_regex !r}' - ) + assert ( + _matches_pattern(expected_msg_regex, actual_msg) is not None + ), f"{actual_msg !r} does not match {expected_msg_regex !r}" def test_unhandled_exception_in_request_handler( - mocker, - monkeypatch, - test_client, - testing_server, - wsgi_server_thread, + mocker, + monkeypatch, + test_client, + testing_server, + wsgi_server_thread, ): """Ensure worker threads are resilient to in-handler exceptions.""" @@ -819,18 +855,18 @@ class SillyMistake(BaseException): # noqa: WPS418, WPS431 """A simulated crash within an HTTP handler.""" def _trigger_scary_exc(_req, _resp): - raise SillyMistake('simulated unhandled exception 💣 in test handler') + raise SillyMistake("simulated unhandled exception 💣 in test handler") - testing_server.wsgi_app.handlers['/scary_exc'] = _trigger_scary_exc + testing_server.wsgi_app.handlers["/scary_exc"] = _trigger_scary_exc server_connection_close_spy = mocker.spy( test_client.server_instance.ConnectionClass, - 'close', + "close", ) http_conn = test_client.get_connection() - http_conn.putrequest('GET', '/scary_exc', skip_host=True) - http_conn.putheader('Host', http_conn.host) + http_conn.putrequest("GET", "/scary_exc", skip_host=True) + http_conn.putheader("Host", http_conn.host) http_conn.endheaders() # NOTE: This spy ensure the log entry gets recorded before we're testing @@ -843,7 +879,7 @@ def _trigger_scary_exc(_req, _resp): while testing_server.requests.idle < 10: # noqa: WPS328 pass assert len(testing_server.requests._threads) == 10 - testing_server.interrupt = SystemExit('test requesting shutdown') + testing_server.interrupt = SystemExit("test requesting shutdown") assert not testing_server.requests._threads wsgi_server_thread.join() # no extra logs upon server termination @@ -853,45 +889,41 @@ def _trigger_scary_exc(_req, _resp): expected_log_entries = ( ( logging.ERROR, - '^Unhandled error while processing an incoming connection ' - 'SillyMistake' + "^Unhandled error while processing an incoming connection " + "SillyMistake" r"\('simulated unhandled exception 💣 in test handler',?\)$", ), ( logging.INFO, - '^SystemExit raised: shutting down$', + "^SystemExit raised: shutting down$", ), ) assert len(actual_log_entries) == len(expected_log_entries) for ( # noqa: WPS352 - (expected_log_level, expected_msg_regex), - (actual_msg, actual_log_level, _tb), + (expected_log_level, expected_msg_regex), + (actual_msg, actual_log_level, _tb), ) in zip(expected_log_entries, actual_log_entries): assert expected_log_level == actual_log_level - assert _matches_pattern(expected_msg_regex, actual_msg) is not None, ( - f'{actual_msg !r} does not match {expected_msg_regex !r}' - ) + assert ( + _matches_pattern(expected_msg_regex, actual_msg) is not None + ), f"{actual_msg !r} does not match {expected_msg_regex !r}" def test_remains_alive_post_unhandled_exception( - mocker, - monkeypatch, - test_client, - testing_server, - wsgi_server_thread, + mocker, + monkeypatch, + test_client, + testing_server, + wsgi_server_thread, ): """Ensure worker threads are resilient to unhandled exceptions.""" class ScaryCrash(BaseException): # noqa: WPS418, WPS431 """A simulated crash during HTTP parsing.""" - _orig_read_request_line = ( - test_client.server_instance. - ConnectionClass.RequestHandlerClass. - read_request_line - ) + _orig_read_request_line = test_client.server_instance.ConnectionClass.RequestHandlerClass.read_request_line def _read_request_line(self): _orig_read_request_line(self) @@ -899,19 +931,19 @@ def _read_request_line(self): monkeypatch.setattr( test_client.server_instance.ConnectionClass.RequestHandlerClass, - 'read_request_line', + "read_request_line", _read_request_line, ) server_connection_close_spy = mocker.spy( test_client.server_instance.ConnectionClass, - 'close', + "close", ) # NOTE: The initial worker thread count is 10. assert len(testing_server.requests._threads) == 10 - test_client.get_connection().send(b'GET / HTTP/1.1') + test_client.get_connection().send(b"GET / HTTP/1.1") # NOTE: This spy ensure the log entry gets recorded before we're testing # NOTE: them and before server shutdown, preserving their order and making @@ -924,10 +956,9 @@ def _read_request_line(self): pass assert len(testing_server.requests._threads) == 10 assert all( - worker_thread.is_alive() - for worker_thread in testing_server.requests._threads + worker_thread.is_alive() for worker_thread in testing_server.requests._threads ) - testing_server.interrupt = SystemExit('test requesting shutdown') + testing_server.interrupt = SystemExit("test requesting shutdown") assert not testing_server.requests._threads wsgi_server_thread.join() # no extra logs upon server termination @@ -937,29 +968,29 @@ def _read_request_line(self): expected_log_entries = ( ( logging.ERROR, - '^Unhandled error while processing an incoming connection ' - r'ScaryCrash\(666,?\)$', + "^Unhandled error while processing an incoming connection " + r"ScaryCrash\(666,?\)$", ), ( logging.INFO, - '^SystemExit raised: shutting down$', + "^SystemExit raised: shutting down$", ), ) assert len(actual_log_entries) == len(expected_log_entries) for ( # noqa: WPS352 - (expected_log_level, expected_msg_regex), - (actual_msg, actual_log_level, _tb), + (expected_log_level, expected_msg_regex), + (actual_msg, actual_log_level, _tb), ) in zip(expected_log_entries, actual_log_entries): assert expected_log_level == actual_log_level - assert _matches_pattern(expected_msg_regex, actual_msg) is not None, ( - f'{actual_msg !r} does not match {expected_msg_regex !r}' - ) + assert ( + _matches_pattern(expected_msg_regex, actual_msg) is not None + ), f"{actual_msg !r} does not match {expected_msg_regex !r}" @pytest.mark.parametrize( - 'timeout_before_headers', + "timeout_before_headers", ( True, False, @@ -976,15 +1007,15 @@ def test_HTTP11_Timeout(test_client, timeout_before_headers): if not timeout_before_headers: # Connect but send half the headers only. - conn.send(b'GET /hello HTTP/1.1') - conn.send(('Host: %s' % conn.host).encode('ascii')) + conn.send(b"GET /hello HTTP/1.1") + conn.send(("Host: %s" % conn.host).encode("ascii")) # else: Connect but send nothing. # Wait for our socket timeout time.sleep(timeout * 2) # The request should have returned 408 already. - response = conn.response_class(conn.sock, method='GET') + response = conn.response_class(conn.sock, method="GET") response.begin() assert response.status == 408 conn.close() @@ -999,10 +1030,10 @@ def test_HTTP11_Timeout_after_request(test_client): # Make an initial request conn = test_client.get_connection() - conn.putrequest('GET', '/timeout?t=%s' % timeout, skip_host=True) - conn.putheader('Host', conn.host) + conn.putrequest("GET", "/timeout?t=%s" % timeout, skip_host=True) + conn.putheader("Host", conn.host) conn.endheaders() - response = conn.response_class(conn.sock, method='GET') + response = conn.response_class(conn.sock, method="GET") response.begin() assert response.status == 200 actual_body = response.read() @@ -1010,24 +1041,24 @@ def test_HTTP11_Timeout_after_request(test_client): assert actual_body == expected_body # Make a second request on the same socket - conn._output(b'GET /hello HTTP/1.1') - conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._output(b"GET /hello HTTP/1.1") + conn._output(("Host: %s" % conn.host).encode("ascii")) conn._send_output() - response = conn.response_class(conn.sock, method='GET') + response = conn.response_class(conn.sock, method="GET") response.begin() assert response.status == 200 actual_body = response.read() - expected_body = b'Hello, world!' + expected_body = b"Hello, world!" assert actual_body == expected_body # Wait for our socket timeout time.sleep(timeout * 2) # Make another request on the same socket, which should error - conn._output(b'GET /hello HTTP/1.1') - conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._output(b"GET /hello HTTP/1.1") + conn._output(("Host: %s" % conn.host).encode("ascii")) conn._send_output() - response = conn.response_class(conn.sock, method='GET') + response = conn.response_class(conn.sock, method="GET") try: response.begin() except (socket.error, http.client.BadStatusLine): @@ -1042,10 +1073,10 @@ def test_HTTP11_Timeout_after_request(test_client): # Make another request on a new socket, which should work conn = test_client.get_connection() - conn.putrequest('GET', '/pov', skip_host=True) - conn.putheader('Host', conn.host) + conn.putrequest("GET", "/pov", skip_host=True) + conn.putheader("Host", conn.host) conn.endheaders() - response = conn.response_class(conn.sock, method='GET') + response = conn.response_class(conn.sock, method="GET") response.begin() assert response.status == 200 actual_body = response.read() @@ -1054,10 +1085,10 @@ def test_HTTP11_Timeout_after_request(test_client): # Make another request on the same socket, # but timeout on the headers - conn.send(b'GET /hello HTTP/1.1') + conn.send(b"GET /hello HTTP/1.1") # Wait for our socket timeout time.sleep(timeout * 2) - response = conn.response_class(conn.sock, method='GET') + response = conn.response_class(conn.sock, method="GET") try: response.begin() except (socket.error, http.client.BadStatusLine): @@ -1072,10 +1103,10 @@ def test_HTTP11_Timeout_after_request(test_client): # Retry the request on a new connection, which should work conn = test_client.get_connection() - conn.putrequest('GET', '/pov', skip_host=True) - conn.putheader('Host', conn.host) + conn.putrequest("GET", "/pov", skip_host=True) + conn.putheader("Host", conn.host) conn.endheaders() - response = conn.response_class(conn.sock, method='GET') + response = conn.response_class(conn.sock, method="GET") response.begin() assert response.status == 200 actual_body = response.read() @@ -1092,36 +1123,36 @@ def test_HTTP11_pipelining(test_client): conn = test_client.get_connection() # Put request 1 - conn.putrequest('GET', '/hello', skip_host=True) - conn.putheader('Host', conn.host) + conn.putrequest("GET", "/hello", skip_host=True) + conn.putheader("Host", conn.host) conn.endheaders() for trial in range(5): # Put next request conn._output( - ('GET /hello?%s HTTP/1.1' % trial).encode('iso-8859-1'), + ("GET /hello?%s HTTP/1.1" % trial).encode("iso-8859-1"), ) - conn._output(('Host: %s' % conn.host).encode('ascii')) + conn._output(("Host: %s" % conn.host).encode("ascii")) conn._send_output() # Retrieve previous response - response = conn.response_class(conn.sock, method='GET') + response = conn.response_class(conn.sock, method="GET") # there is a bug in python3 regarding the buffering of # ``conn.sock``. Until that bug get's fixed we will # monkey patch the ``response`` instance. # https://bugs.python.org/issue23377 - response.fp = conn.sock.makefile('rb', 0) + response.fp = conn.sock.makefile("rb", 0) response.begin() body = response.read(13) assert response.status == 200 - assert body == b'Hello, world!' + assert body == b"Hello, world!" # Retrieve final response - response = conn.response_class(conn.sock, method='GET') + response = conn.response_class(conn.sock, method="GET") response.begin() body = response.read() assert response.status == 200 - assert body == b'Hello, world!' + assert body == b"Hello, world!" conn.close() @@ -1133,26 +1164,26 @@ def test_100_Continue(test_client): # Try a page without an Expect request header first. # Note that http.client's response.begin automatically ignores # 100 Continue responses, so we must manually check for it. - conn.putrequest('POST', '/upload', skip_host=True) - conn.putheader('Host', conn.host) - conn.putheader('Content-Type', 'text/plain') - conn.putheader('Content-Length', '4') + conn.putrequest("POST", "/upload", skip_host=True) + conn.putheader("Host", conn.host) + conn.putheader("Content-Type", "text/plain") + conn.putheader("Content-Length", "4") conn.endheaders() conn.send(b"d'oh") - response = conn.response_class(conn.sock, method='POST') + response = conn.response_class(conn.sock, method="POST") _version, status, _reason = response._read_status() assert status != 100 conn.close() # Now try a page with an Expect header... conn.connect() - conn.putrequest('POST', '/upload', skip_host=True) - conn.putheader('Host', conn.host) - conn.putheader('Content-Type', 'text/plain') - conn.putheader('Content-Length', '17') - conn.putheader('Expect', '100-continue') + conn.putrequest("POST", "/upload", skip_host=True) + conn.putheader("Host", conn.host) + conn.putheader("Content-Type", "text/plain") + conn.putheader("Content-Length", "17") + conn.putheader("Expect", "100-continue") conn.endheaders() - response = conn.response_class(conn.sock, method='POST') + response = conn.response_class(conn.sock, method="POST") # ...assert and then skip the 100 response version, status, reason = response._read_status() @@ -1161,14 +1192,13 @@ def test_100_Continue(test_client): line = response.fp.readline().strip() if line: pytest.fail( - '100 Continue should not output any headers. Got %r' % - line, + "100 Continue should not output any headers. Got %r" % line, ) else: break # ...send the body - body = b'I am a small file' + body = b"I am a small file" conn.send(body) # ...get the final response @@ -1182,7 +1212,7 @@ def test_100_Continue(test_client): @pytest.mark.parametrize( - 'max_request_body_size', + "max_request_body_size", ( 0, 1001, @@ -1197,13 +1227,13 @@ def test_readall_or_close(test_client, max_request_body_size): conn = test_client.get_connection() # Get a POST page with an error - conn.putrequest('POST', '/err_before_read', skip_host=True) - conn.putheader('Host', conn.host) - conn.putheader('Content-Type', 'text/plain') - conn.putheader('Content-Length', '1000') - conn.putheader('Expect', '100-continue') + conn.putrequest("POST", "/err_before_read", skip_host=True) + conn.putheader("Host", conn.host) + conn.putheader("Content-Type", "text/plain") + conn.putheader("Content-Length", "1000") + conn.putheader("Expect", "100-continue") conn.endheaders() - response = conn.response_class(conn.sock, method='POST') + response = conn.response_class(conn.sock, method="POST") # ...assert and then skip the 100 response _version, status, _reason = response._read_status() @@ -1213,7 +1243,7 @@ def test_readall_or_close(test_client, max_request_body_size): skip = response.fp.readline().strip() # ...send the body - conn.send(b'x' * 1000) + conn.send(b"x" * 1000) # ...get the final response response.begin() @@ -1222,13 +1252,13 @@ def test_readall_or_close(test_client, max_request_body_size): assert actual_status == 500 # Now try a working page with an Expect header... - conn._output(b'POST /upload HTTP/1.1') - conn._output(('Host: %s' % conn.host).encode('ascii')) - conn._output(b'Content-Type: text/plain') - conn._output(b'Content-Length: 17') - conn._output(b'Expect: 100-continue') + conn._output(b"POST /upload HTTP/1.1") + conn._output(("Host: %s" % conn.host).encode("ascii")) + conn._output(b"Content-Type: text/plain") + conn._output(b"Content-Length: 17") + conn._output(b"Expect: 100-continue") conn._send_output() - response = conn.response_class(conn.sock, method='POST') + response = conn.response_class(conn.sock, method="POST") # ...assert and then skip the 100 response version, status, reason = response._read_status() @@ -1238,7 +1268,7 @@ def test_readall_or_close(test_client, max_request_body_size): skip = response.fp.readline().strip() # ...send the body - body = b'I am a small file' + body = b"I am a small file" conn.send(body) # ...get the final response @@ -1262,33 +1292,36 @@ def test_No_Message_Body(test_client): # Make the first request and assert there's no "Connection: close". status_line, actual_headers, actual_resp_body = test_client.get( - '/pov', http_conn=http_connection, + "/pov", + http_conn=http_connection, ) actual_status = int(status_line[:3]) assert actual_status == 200 - assert status_line[4:] == 'OK' + assert status_line[4:] == "OK" assert actual_resp_body == pov.encode() - assert not header_exists('Connection', actual_headers) + assert not header_exists("Connection", actual_headers) # Make a 204 request on the same connection. status_line, actual_headers, actual_resp_body = test_client.get( - '/custom/204', http_conn=http_connection, + "/custom/204", + http_conn=http_connection, ) actual_status = int(status_line[:3]) assert actual_status == 204 - assert not header_exists('Content-Length', actual_headers) - assert actual_resp_body == b'' - assert not header_exists('Connection', actual_headers) + assert not header_exists("Content-Length", actual_headers) + assert actual_resp_body == b"" + assert not header_exists("Connection", actual_headers) # Make a 304 request on the same connection. status_line, actual_headers, actual_resp_body = test_client.get( - '/custom/304', http_conn=http_connection, + "/custom/304", + http_conn=http_connection, ) actual_status = int(status_line[:3]) assert actual_status == 304 - assert not header_exists('Content-Length', actual_headers) - assert actual_resp_body == b'' - assert not header_exists('Connection', actual_headers) + assert not header_exists("Content-Length", actual_headers) + assert actual_resp_body == b"" + assert not header_exists("Connection", actual_headers) # Prevent the resource warnings: http_connection.close() @@ -1310,35 +1343,35 @@ def test_Chunked_Encoding(test_client): # Try a normal chunked request (with extensions) body = ( - b'8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n' - b'Content-Type: application/json\r\n' - b'\r\n' + b"8;key=value\r\nxx\r\nxxxx\r\n5\r\nyyyyy\r\n0\r\n" + b"Content-Type: application/json\r\n" + b"\r\n" ) - conn.putrequest('POST', '/upload', skip_host=True) - conn.putheader('Host', conn.host) - conn.putheader('Transfer-Encoding', 'chunked') - conn.putheader('Trailer', 'Content-Type') + conn.putrequest("POST", "/upload", skip_host=True) + conn.putheader("Host", conn.host) + conn.putheader("Transfer-Encoding", "chunked") + conn.putheader("Trailer", "Content-Type") # Note that this is somewhat malformed: # we shouldn't be sending Content-Length. # RFC 2616 says the server should ignore it. - conn.putheader('Content-Length', '3') + conn.putheader("Content-Length", "3") conn.endheaders() conn.send(body) response = conn.getresponse() status_line, _actual_headers, actual_resp_body = webtest.shb(response) actual_status = int(status_line[:3]) assert actual_status == 200 - assert status_line[4:] == 'OK' - expected_resp_body = ("thanks for '%s'" % b'xx\r\nxxxxyyyyy').encode() + assert status_line[4:] == "OK" + expected_resp_body = ("thanks for '%s'" % b"xx\r\nxxxxyyyyy").encode() assert actual_resp_body == expected_resp_body # Try a chunked request that exceeds server.max_request_body_size. # Note that the delimiters and trailer are included. - body = b'\r\n'.join((b'3e3', b'x' * 995, b'0', b'', b'')) - conn.putrequest('POST', '/upload', skip_host=True) - conn.putheader('Host', conn.host) - conn.putheader('Transfer-Encoding', 'chunked') - conn.putheader('Content-Type', 'text/plain') + body = b"\r\n".join((b"3e3", b"x" * 995, b"0", b"", b"")) + conn.putrequest("POST", "/upload", skip_host=True) + conn.putheader("Host", conn.host) + conn.putheader("Transfer-Encoding", "chunked") + conn.putheader("Content-Type", "text/plain") # Chunked requests don't need a content-length # conn.putheader("Content-Length", len(body)) conn.endheaders() @@ -1359,18 +1392,17 @@ def test_Content_Length_in(test_client): # Initialize a persistent HTTP connection conn = test_client.get_connection() - conn.putrequest('POST', '/upload', skip_host=True) - conn.putheader('Host', conn.host) - conn.putheader('Content-Type', 'text/plain') - conn.putheader('Content-Length', '9999') + conn.putrequest("POST", "/upload", skip_host=True) + conn.putheader("Host", conn.host) + conn.putheader("Content-Type", "text/plain") + conn.putheader("Content-Length", "9999") conn.endheaders() response = conn.getresponse() status_line, _actual_headers, actual_resp_body = webtest.shb(response) actual_status = int(status_line[:3]) assert actual_status == 413 expected_resp_body = ( - b'The entity sent with the request exceeds ' - b'the maximum allowed bytes.' + b"The entity sent with the request exceeds " b"the maximum allowed bytes." ) assert actual_resp_body == expected_resp_body conn.close() @@ -1379,42 +1411,45 @@ def test_Content_Length_in(test_client): def test_Content_Length_not_int(test_client): """Test that malicious Content-Length header returns 400.""" status_line, _actual_headers, actual_resp_body = test_client.post( - '/upload', + "/upload", headers=[ - ('Content-Type', 'text/plain'), - ('Content-Length', 'not-an-integer'), + ("Content-Type", "text/plain"), + ("Content-Length", "not-an-integer"), ], ) actual_status = int(status_line[:3]) assert actual_status == 400 - assert actual_resp_body == b'Malformed Content-Length Header.' + assert actual_resp_body == b"Malformed Content-Length Header." @pytest.mark.parametrize( - ('uri', 'expected_resp_status', 'expected_resp_body'), + ("uri", "expected_resp_status", "expected_resp_body"), ( ( - '/wrong_cl_buffered', 500, + "/wrong_cl_buffered", + 500, ( - b'The requested resource returned more bytes than the ' - b'declared Content-Length.' + b"The requested resource returned more bytes than the " + b"declared Content-Length." ), ), - ('/wrong_cl_unbuffered', 200, b'I too'), + ("/wrong_cl_unbuffered", 200, b"I too"), ), ) def test_Content_Length_out( test_client, - uri, expected_resp_status, expected_resp_body, + uri, + expected_resp_status, + expected_resp_body, ): """Test response with Content-Length less than the response body. (non-chunked response) """ conn = test_client.get_connection() - conn.putrequest('GET', uri, skip_host=True) - conn.putheader('Host', conn.host) + conn.putrequest("GET", uri, skip_host=True) + conn.putheader("Host", conn.host) conn.endheaders() response = conn.getresponse() @@ -1429,25 +1464,26 @@ def test_Content_Length_out( # the server logs the exception that we had verified from the # client perspective. Tell the error_log verification that # it can ignore that message. - test_client.server_instance.error_log.ignored_msgs.extend(( - # Python 3.7+: - "ValueError('Response body exceeds the declared Content-Length.')", - # Python 2.7-3.6 (macOS?): - "ValueError('Response body exceeds the declared Content-Length.',)", - )) + test_client.server_instance.error_log.ignored_msgs.extend( + ( + # Python 3.7+: + "ValueError('Response body exceeds the declared Content-Length.')", + # Python 2.7-3.6 (macOS?): + "ValueError('Response body exceeds the declared Content-Length.',)", + ) + ) @pytest.mark.xfail( - reason='Sometimes this test fails due to low timeout. ' - 'Ref: https://github.com/cherrypy/cherrypy/issues/598', + reason="Sometimes this test fails due to low timeout. " + "Ref: https://github.com/cherrypy/cherrypy/issues/598", ) def test_598(test_client): """Test serving large file with a read timeout in place.""" # Initialize a persistent HTTP connection conn = test_client.get_connection() remote_data_conn = urllib.request.urlopen( - '%s://%s:%s/one_megabyte_of_a' - % ('http', conn.host, conn.port), + "%s://%s:%s/one_megabyte_of_a" % ("http", conn.host, conn.port), ) buf = remote_data_conn.read(512) time.sleep(timeout * 0.6) @@ -1460,16 +1496,16 @@ def test_598(test_client): remaining -= len(data) assert len(buf) == 1024 * 1024 - assert buf == b'a' * 1024 * 1024 + assert buf == b"a" * 1024 * 1024 assert remaining == 0 remote_data_conn.close() @pytest.mark.parametrize( - 'invalid_terminator', + "invalid_terminator", ( - b'\n\n', - b'\r\n\n', + b"\n\n", + b"\r\n\n", ), ) def test_No_CRLF(test_client, invalid_terminator): @@ -1477,11 +1513,11 @@ def test_No_CRLF(test_client, invalid_terminator): # Initialize a persistent HTTP connection conn = test_client.get_connection() - conn.send(b'GET /hello HTTP/1.1%s' % invalid_terminator) - response = conn.response_class(conn.sock, method='GET') + conn.send(b"GET /hello HTTP/1.1%s" % invalid_terminator) + response = conn.response_class(conn.sock, method="GET") response.begin() actual_resp_body = response.read() - expected_resp_body = b'HTTP requires CRLF terminators' + expected_resp_body = b"HTTP requires CRLF terminators" assert actual_resp_body == expected_resp_body conn.close() @@ -1499,7 +1535,7 @@ def __call__(self, timeout): """Intercept the calls to selector.select.""" if self.request_served: self.os_error_triggered = True - raise OSError('Error while selecting the client socket.') + raise OSError("Error while selecting the client socket.") return self.original_select(timeout) @@ -1516,9 +1552,14 @@ def __init__(self, original_get_map): def __call__(self): """Intercept the calls to selector.get_map.""" sabotage_targets = ( - conn for _, (_, _, _, conn) in self.original_get_map().items() - if isinstance(conn, cheroot.server.HTTPConnection) - ) if self.sabotage_conn and not self.conn_closed else () + ( + conn + for _, (_, _, _, conn) in self.original_get_map().items() + if isinstance(conn, cheroot.server.HTTPConnection) + ) + if self.sabotage_conn and not self.conn_closed + else () + ) for conn in sabotage_targets: # close the socket to cause OSError @@ -1539,7 +1580,7 @@ def test_invalid_selected_connection(test_client, monkeypatch): ) monkeypatch.setattr( test_client.server_instance._connections._selector, - 'select', + "select", faux_select, ) @@ -1550,17 +1591,18 @@ def test_invalid_selected_connection(test_client, monkeypatch): monkeypatch.setattr( test_client.server_instance._connections._selector._selector, - 'get_map', + "get_map", faux_get_map, ) # request a page with connection keep-alive to make sure # we'll have a connection to be modified. resp_status, _resp_headers, _resp_body = test_client.request( - '/page1', headers=[('Connection', 'Keep-Alive')], + "/page1", + headers=[("Connection", "Keep-Alive")], ) - assert resp_status == '200 OK' + assert resp_status == "200 OK" # trigger the internal errors faux_get_map.sabotage_conn = faux_select.request_served = True # give time to make sure the error gets handled diff --git a/cheroot/test/test_core.py b/cheroot/test/test_core.py index d33647888e..fe367129a5 100644 --- a/cheroot/test/test_core.py +++ b/cheroot/test/test_core.py @@ -22,24 +22,24 @@ class HelloController(helper.Controller): def hello(req, resp): """Render Hello world.""" - return 'Hello world!' + return "Hello world!" def body_required(req, resp): """Render Hello world or set 411.""" - if req.environ.get('Content-Length', None) is None: - resp.status = '411 Length Required' + if req.environ.get("Content-Length", None) is None: + resp.status = "411 Length Required" return - return 'Hello world!' + return "Hello world!" def query_string(req, resp): """Render QUERY_STRING value.""" - return req.environ.get('QUERY_STRING', '') + return req.environ.get("QUERY_STRING", "") def asterisk(req, resp): """Render request method value.""" # pylint: disable=possibly-unused-variable - method = req.environ.get('REQUEST_METHOD', 'NO METHOD FOUND') - tmpl = 'Got asterisk URI path with {method} method' + method = req.environ.get("REQUEST_METHOD", "NO METHOD FOUND") + tmpl = "Got asterisk URI path with {method} method" return tmpl.format(**locals()) def _munge(string): @@ -48,27 +48,27 @@ def _munge(string): WSGI 1.0 is a mess around unicode. Create endpoints that match the PATH_INFO that it produces. """ - return string.encode('utf-8').decode('latin-1') + return string.encode("utf-8").decode("latin-1") handlers = { - '/hello': hello, - '/no_body': hello, - '/body_required': body_required, - '/query_string': query_string, + "/hello": hello, + "/no_body": hello, + "/body_required": body_required, + "/query_string": query_string, # FIXME: Unignore the pylint rules in pylint >= 2.15.4. # Refs: # * https://github.com/PyCQA/pylint/issues/6592 # * https://github.com/PyCQA/pylint/pull/7395 # pylint: disable-next=too-many-function-args - _munge('/привіт'): hello, + _munge("/привіт"): hello, # pylint: disable-next=too-many-function-args - _munge('/Юххууу'): hello, - '/\xa0Ðblah key 0 900 4 data': hello, - '/*': asterisk, + _munge("/Юххууу"): hello, + "/\xa0Ðblah key 0 900 4 data": hello, + "/*": asterisk, } -def _get_http_response(connection, method='GET'): +def _get_http_response(connection, method="GET"): return connection.response_class(connection.sock, method=method) @@ -105,36 +105,36 @@ def test_client_with_defaults(testing_server_with_defaults): def test_http_connect_request(test_client): """Check that CONNECT query results in Method Not Allowed status.""" - status_line = test_client.connect('/anything')[0] + status_line = test_client.connect("/anything")[0] actual_status = int(status_line[:3]) assert actual_status == 405 def test_normal_request(test_client): """Check that normal GET query succeeds.""" - status_line, _, actual_resp_body = test_client.get('/hello') + status_line, _, actual_resp_body = test_client.get("/hello") actual_status = int(status_line[:3]) assert actual_status == HTTP_OK - assert actual_resp_body == b'Hello world!' + assert actual_resp_body == b"Hello world!" def test_query_string_request(test_client): """Check that GET param is parsed well.""" status_line, _, actual_resp_body = test_client.get( - '/query_string?test=True', + "/query_string?test=True", ) actual_status = int(status_line[:3]) assert actual_status == HTTP_OK - assert actual_resp_body == b'test=True' + assert actual_resp_body == b"test=True" @pytest.mark.parametrize( - 'uri', + "uri", ( - '/hello', # plain - '/query_string?test=True', # query - '/{0}?{1}={2}'.format( # quoted unicode - *map(urllib.parse.quote, ('Юххууу', 'ї', 'йо')), + "/hello", # plain + "/query_string?test=True", # query + "/{0}?{1}={2}".format( # quoted unicode + *map(urllib.parse.quote, ("Юххууу", "ї", "йо")), ), ), ) @@ -161,16 +161,16 @@ def test_parse_uri_unsafe_uri(test_client): which would be a security issue otherwise. """ c = test_client.get_connection() - resource = '/\xa0Ðblah key 0 900 4 data'.encode('latin-1') + resource = "/\xa0Ðblah key 0 900 4 data".encode("latin-1") quoted = urllib.parse.quote(resource) - assert quoted == '/%A0%D0blah%20key%200%20900%204%20data' - request = 'GET {quoted} HTTP/1.1'.format(**locals()) - c._output(request.encode('utf-8')) + assert quoted == "/%A0%D0blah%20key%200%20900%204%20data" + request = "GET {quoted} HTTP/1.1".format(**locals()) + c._output(request.encode("utf-8")) c._send_output() - response = _get_http_response(c, method='GET') + response = _get_http_response(c, method="GET") response.begin() assert response.status == HTTP_OK - assert response.read(12) == b'Hello world!' + assert response.read(12) == b"Hello world!" c.close() @@ -180,20 +180,20 @@ def test_parse_uri_invalid_uri(test_client): Invalid request line test case: it should only contain US-ASCII. """ c = test_client.get_connection() - c._output(u'GET /йопта! HTTP/1.1'.encode('utf-8')) + c._output("GET /йопта! HTTP/1.1".encode("utf-8")) c._send_output() - response = _get_http_response(c, method='GET') + response = _get_http_response(c, method="GET") response.begin() assert response.status == HTTP_BAD_REQUEST - assert response.read(21) == b'Malformed Request-URI' + assert response.read(21) == b"Malformed Request-URI" c.close() @pytest.mark.parametrize( - 'uri', + "uri", ( - 'hello', # ascii - 'привіт', # non-ascii + "hello", # ascii + "привіт", # non-ascii ), ) def test_parse_no_leading_slash_invalid(test_client, uri): @@ -206,7 +206,7 @@ def test_parse_no_leading_slash_invalid(test_client, uri): ) actual_status = int(status_line[:3]) assert actual_status == HTTP_BAD_REQUEST - assert b'starting with a slash' in actual_resp_body + assert b"starting with a slash" in actual_resp_body def test_parse_uri_absolute_uri(test_client): @@ -214,30 +214,30 @@ def test_parse_uri_absolute_uri(test_client): Only proxy servers should allow this. """ - status_line, _, actual_resp_body = test_client.get('http://google.com/') + status_line, _, actual_resp_body = test_client.get("http://google.com/") actual_status = int(status_line[:3]) assert actual_status == HTTP_BAD_REQUEST - expected_body = b'Absolute URI not allowed if server is not a proxy.' + expected_body = b"Absolute URI not allowed if server is not a proxy." assert actual_resp_body == expected_body def test_parse_uri_asterisk_uri(test_client): """Check that server responds with OK to OPTIONS with "*" Absolute URI.""" - status_line, _, actual_resp_body = test_client.options('*') + status_line, _, actual_resp_body = test_client.options("*") actual_status = int(status_line[:3]) assert actual_status == HTTP_OK - expected_body = b'Got asterisk URI path with OPTIONS method' + expected_body = b"Got asterisk URI path with OPTIONS method" assert actual_resp_body == expected_body def test_parse_uri_fragment_uri(test_client): """Check that server responds with Bad Request to URI with fragment.""" status_line, _, actual_resp_body = test_client.get( - '/hello?test=something#fake', + "/hello?test=something#fake", ) actual_status = int(status_line[:3]) assert actual_status == HTTP_BAD_REQUEST - expected_body = b'Illegal #fragment in Request-URI.' + expected_body = b"Illegal #fragment in Request-URI." assert actual_resp_body == expected_body @@ -249,12 +249,12 @@ def test_no_content_length(test_client): # # Send a message with neither header and no body. c = test_client.get_connection() - c.request('POST', '/no_body') + c.request("POST", "/no_body") response = c.getresponse() actual_resp_body = response.read() actual_status = response.status assert actual_status == HTTP_OK - assert actual_resp_body == b'Hello world!' + assert actual_resp_body == b"Hello world!" c.close() # deal with the resource warning @@ -266,7 +266,7 @@ def test_content_length_required(test_client): # with 411 Length Required. c = test_client.get_connection() - c.request('POST', '/body_required') + c.request("POST", "/body_required") response = c.getresponse() response.read() @@ -277,7 +277,7 @@ def test_content_length_required(test_client): @pytest.mark.xfail( - reason='https://github.com/cherrypy/cheroot/issues/106', + reason="https://github.com/cherrypy/cheroot/issues/106", strict=False, # sometimes it passes ) def test_large_request(test_client_with_defaults): @@ -288,8 +288,8 @@ def test_large_request(test_client_with_defaults): # We expect that this should instead return that the request is too # large. c = test_client_with_defaults.get_connection() - c.putrequest('GET', '/hello') - c.putheader('Content-Length', str(2**64)) + c.putrequest("GET", "/hello") + c.putheader("Content-Length", str(2**64)) c.endheaders() response = c.getresponse() @@ -299,35 +299,41 @@ def test_large_request(test_client_with_defaults): @pytest.mark.parametrize( - ('request_line', 'status_code', 'expected_body'), + ("request_line", "status_code", "expected_body"), ( ( - b'GET /', # missing proto - HTTP_BAD_REQUEST, b'Malformed Request-Line', + b"GET /", # missing proto + HTTP_BAD_REQUEST, + b"Malformed Request-Line", ), ( - b'GET / HTTPS/1.1', # invalid proto - HTTP_BAD_REQUEST, b'Malformed Request-Line: bad protocol', + b"GET / HTTPS/1.1", # invalid proto + HTTP_BAD_REQUEST, + b"Malformed Request-Line: bad protocol", ), ( - b'GET / HTTP/1', # invalid version - HTTP_BAD_REQUEST, b'Malformed Request-Line: bad version', + b"GET / HTTP/1", # invalid version + HTTP_BAD_REQUEST, + b"Malformed Request-Line: bad version", ), ( - b'GET / HTTP/2.15', # invalid ver - HTTP_VERSION_NOT_SUPPORTED, b'Cannot fulfill request', + b"GET / HTTP/2.15", # invalid ver + HTTP_VERSION_NOT_SUPPORTED, + b"Cannot fulfill request", ), ), ) def test_malformed_request_line( - test_client, request_line, - status_code, expected_body, + test_client, + request_line, + status_code, + expected_body, ): """Test missing or invalid HTTP version in Request-Line.""" c = test_client.get_connection() c._output(request_line) c._send_output() - response = _get_http_response(c, method='GET') + response = _get_http_response(c, method="GET") response.begin() assert response.status == status_code assert response.read(len(expected_body)) == expected_body @@ -337,15 +343,15 @@ def test_malformed_request_line( def test_malformed_http_method(test_client): """Test non-uppercase HTTP method.""" c = test_client.get_connection() - c.putrequest('GeT', '/malformed_method_case') - c.putheader('Content-Type', 'text/plain') + c.putrequest("GeT", "/malformed_method_case") + c.putheader("Content-Type", "text/plain") c.endheaders() response = c.getresponse() actual_status = response.status assert actual_status == HTTP_BAD_REQUEST actual_resp_body = response.read(21) - assert actual_resp_body == b'Malformed method name' + assert actual_resp_body == b"Malformed method name" c.close() # deal with the resource warning @@ -353,17 +359,17 @@ def test_malformed_http_method(test_client): def test_malformed_header(test_client): """Check that broken HTTP header results in Bad Request.""" c = test_client.get_connection() - c.putrequest('GET', '/') - c.putheader('Content-Type', 'text/plain') + c.putrequest("GET", "/") + c.putheader("Content-Type", "text/plain") # See https://www.bitbucket.org/cherrypy/cherrypy/issue/941 - c._output(b'Re, 1.2.3.4#015#012') + c._output(b"Re, 1.2.3.4#015#012") c.endheaders() response = c.getresponse() actual_status = response.status assert actual_status == HTTP_BAD_REQUEST actual_resp_body = response.read(20) - assert actual_resp_body == b'Illegal header line.' + assert actual_resp_body == b"Illegal header line." c.close() # deal with the resource warning @@ -371,18 +377,18 @@ def test_malformed_header(test_client): def test_request_line_split_issue_1220(test_client): """Check that HTTP request line of exactly 256 chars length is OK.""" Request_URI = ( - '/hello?' - 'intervenant-entreprise-evenement_classaction=' - 'evenement-mailremerciements' - '&_path=intervenant-entreprise-evenement' - '&intervenant-entreprise-evenement_action-id=19404' - '&intervenant-entreprise-evenement_id=19404' - '&intervenant-entreprise_id=28092' + "/hello?" + "intervenant-entreprise-evenement_classaction=" + "evenement-mailremerciements" + "&_path=intervenant-entreprise-evenement" + "&intervenant-entreprise-evenement_action-id=19404" + "&intervenant-entreprise-evenement_id=19404" + "&intervenant-entreprise_id=28092" ) - assert len('GET %s HTTP/1.1\r\n' % Request_URI) == 256 + assert len("GET %s HTTP/1.1\r\n" % Request_URI) == 256 actual_resp_body = test_client.get(Request_URI)[2] - assert actual_resp_body == b'Hello world!' + assert actual_resp_body == b"Hello world!" def test_garbage_in(test_client): @@ -390,15 +396,15 @@ def test_garbage_in(test_client): # Connect without SSL regardless of server.scheme c = test_client.get_connection() - c._output(b'gjkgjklsgjklsgjkljklsg') + c._output(b"gjkgjklsgjklsgjkljklsg") c._send_output() - response = c.response_class(c.sock, method='GET') + response = c.response_class(c.sock, method="GET") try: response.begin() actual_status = response.status assert actual_status == HTTP_BAD_REQUEST actual_resp_body = response.read(22) - assert actual_resp_body == b'Malformed Request-Line' + assert actual_resp_body == b"Malformed Request-Line" c.close() except socket.error as ex: # "Connection reset by peer" is also acceptable. @@ -418,7 +424,7 @@ def __call__(self, environ, start_response): def close(self): """Close, writing hello.""" - self.req.write(b'hello') + self.req.write(b"hello") class CloseResponse: @@ -426,8 +432,8 @@ class CloseResponse: def __init__(self, close): """Use some defaults to ensure we have a header.""" - self.status = '200 OK' - self.headers = {'Content-Type': 'text/html'} + self.status = "200 OK" + self.headers = {"Content-Type": "text/html"} self.close = close def __getitem__(self, index): @@ -451,5 +457,5 @@ def testing_server_close(wsgi_server_client): def test_send_header_before_closing(testing_server_close): """Test we are actually sending the headers before calling 'close'.""" - _, _, resp_body = testing_server_close.server_client.get('/') - assert resp_body == b'hello' + _, _, resp_body = testing_server_close.server_client.get("/") + assert resp_body == b"hello" diff --git a/cheroot/test/test_dispatch.py b/cheroot/test/test_dispatch.py index c42014face..290fb3eddb 100644 --- a/cheroot/test/test_dispatch.py +++ b/cheroot/test/test_dispatch.py @@ -8,12 +8,14 @@ def wsgi_invoke(app, environ): response = {} def start_response(status, headers): - response.update({ - 'status': status, - 'headers': headers, - }) + response.update( + { + "status": status, + "headers": headers, + } + ) - response['body'] = b''.join( + response["body"] = b"".join( app(environ, start_response), ) @@ -22,30 +24,35 @@ def start_response(status, headers): def test_dispatch_no_script_name(): """Dispatch despite lack of ``SCRIPT_NAME`` in environ.""" + # Bare bones WSGI hello world app (from PEP 333). def app(environ, start_response): start_response( - '200 OK', [ - ('Content-Type', 'text/plain; charset=utf-8'), + "200 OK", + [ + ("Content-Type", "text/plain; charset=utf-8"), ], ) - return [u'Hello, world!'.encode('utf-8')] + return ["Hello, world!".encode("utf-8")] # Build a dispatch table. - d = PathInfoDispatcher([ - ('/', app), - ]) + d = PathInfoDispatcher( + [ + ("/", app), + ] + ) # Dispatch a request without `SCRIPT_NAME`. response = wsgi_invoke( - d, { - 'PATH_INFO': '/foo', + d, + { + "PATH_INFO": "/foo", }, ) assert response == { - 'status': '200 OK', - 'headers': [ - ('Content-Type', 'text/plain; charset=utf-8'), + "status": "200 OK", + "headers": [ + ("Content-Type", "text/plain; charset=utf-8"), ], - 'body': b'Hello, world!', + "body": b"Hello, world!", } diff --git a/cheroot/test/test_errors.py b/cheroot/test/test_errors.py index a5dd5c2b0c..1df564a45b 100644 --- a/cheroot/test/test_errors.py +++ b/cheroot/test/test_errors.py @@ -8,19 +8,26 @@ @pytest.mark.parametrize( - ('err_names', 'err_nums'), + ("err_names", "err_nums"), ( - (('', 'some-nonsense-name'), []), + (("", "some-nonsense-name"), []), ( ( - 'EPROTOTYPE', 'EAGAIN', 'EWOULDBLOCK', - 'WSAEWOULDBLOCK', 'EPIPE', + "EPROTOTYPE", + "EAGAIN", + "EWOULDBLOCK", + "WSAEWOULDBLOCK", + "EPIPE", ), - (91, 11, 32) if IS_LINUX else - (32, 35, 41) if IS_MACOS else - (98, 11, 32) if IS_SOLARIS else - (32, 10041, 11, 10035) if IS_WINDOWS else - (), + (91, 11, 32) + if IS_LINUX + else (32, 35, 41) + if IS_MACOS + else (98, 11, 32) + if IS_SOLARIS + else (32, 10041, 11, 10035) + if IS_WINDOWS + else (), ), ), ) diff --git a/cheroot/test/test_makefile.py b/cheroot/test/test_makefile.py index 57f6f57ea9..dad3da9c70 100644 --- a/cheroot/test/test_makefile.py +++ b/cheroot/test/test_makefile.py @@ -24,7 +24,7 @@ def recv(self, size): try: return self.messages.pop(0) except IndexError: - return '' + return "" def send(self, val): """Simulate a send.""" @@ -34,8 +34,8 @@ def send(self, val): def test_bytes_read(): """Reader should capture bytes read.""" sock = MockSocket() - sock.messages.append(b'foo') - rfile = makefile.MakeFile(sock, 'r') + sock.messages.append(b"foo") + rfile = makefile.MakeFile(sock, "r") rfile.read() assert rfile.bytes_read == 3 @@ -43,7 +43,7 @@ def test_bytes_read(): def test_bytes_written(): """Writer should capture bytes written.""" sock = MockSocket() - sock.messages.append(b'foo') - wfile = makefile.MakeFile(sock, 'w') - wfile.write(b'bar') + sock.messages.append(b"foo") + wfile = makefile.MakeFile(sock, "w") + wfile.write(b"bar") assert wfile.bytes_written == 3 diff --git a/cheroot/test/test_server.py b/cheroot/test/test_server.py index 3c39773119..7982cd1f7e 100644 --- a/cheroot/test/test_server.py +++ b/cheroot/test/test_server.py @@ -30,21 +30,21 @@ unix_only_sock_test = pytest.mark.skipif( - not hasattr(socket, 'AF_UNIX'), - reason='UNIX domain sockets are only available under UNIX-based OS', + not hasattr(socket, "AF_UNIX"), + reason="UNIX domain sockets are only available under UNIX-based OS", ) non_macos_sock_test = pytest.mark.skipif( IS_MACOS, - reason='Peercreds lookup does not work under macOS/BSD currently.', + reason="Peercreds lookup does not work under macOS/BSD currently.", ) -@pytest.fixture(params=('abstract', 'file')) +@pytest.fixture(params=("abstract", "file")) def unix_sock_file(request): """Check that bound UNIX socket address is stored in server.""" - name = 'unix_{request.param}_sock'.format(**locals()) + name = "unix_{request.param}_sock".format(**locals()) return request.getfixturevalue(name) @@ -53,13 +53,16 @@ def unix_abstract_sock(): """Return an abstract UNIX socket address.""" if not IS_LINUX: pytest.skip( - '{os} does not support an abstract ' - 'socket namespace'.format(os=SYS_PLATFORM), + "{os} does not support an abstract " "socket namespace".format( + os=SYS_PLATFORM + ), ) - return b''.join(( - b'\x00cheroot-test-socket', - ntob(str(uuid.uuid4())), - )).decode() + return b"".join( + ( + b"\x00cheroot-test-socket", + ntob(str(uuid.uuid4())), + ) + ).decode() @pytest.fixture @@ -117,7 +120,7 @@ def test_stop_interrupts_serve(): @pytest.mark.parametrize( - 'exc_cls', + "exc_cls", ( IOError, KeyboardInterrupt, @@ -127,7 +130,7 @@ def test_stop_interrupts_serve(): ) def test_server_interrupt(exc_cls): """Check that assigning interrupt stops the server.""" - interrupt_msg = 'should catch {uuid!s}'.format(uuid=uuid.uuid4()) + interrupt_msg = "should catch {uuid!s}".format(uuid=uuid.uuid4()) raise_marker_sentinel = object() httpserver = HTTPServer( @@ -190,7 +193,7 @@ def raise_keyboard_interrupt(*args, **kwargs): @pytest.mark.parametrize( - 'ip_addr', + "ip_addr", ( ANY_INTERFACE_IPV4, ANY_INTERFACE_IPV6, @@ -220,8 +223,8 @@ def test_bind_addr_unix_abstract(http_server, unix_abstract_sock): assert httpserver.bind_addr == unix_abstract_sock -PEERCRED_IDS_URI = '/peer_creds/ids' -PEERCRED_TEXTS_URI = '/peer_creds/texts' +PEERCRED_IDS_URI = "/peer_creds/ids" +PEERCRED_TEXTS_URI = "/peer_creds/texts" class _TestGateway(Gateway): @@ -231,16 +234,16 @@ def respond(self): req_uri = bton(req.uri) if req_uri == PEERCRED_IDS_URI: peer_creds = conn.peer_pid, conn.peer_uid, conn.peer_gid - self.send_payload('|'.join(map(str, peer_creds))) + self.send_payload("|".join(map(str, peer_creds))) return elif req_uri == PEERCRED_TEXTS_URI: - self.send_payload('!'.join((conn.peer_user, conn.peer_group))) + self.send_payload("!".join((conn.peer_user, conn.peer_group))) return return super(_TestGateway, self).respond() def send_payload(self, payload): req = self.req - req.status = b'200 OK' + req.status = b"200 OK" req.ensure_headers_sent() req.write(ntob(payload)) @@ -266,11 +269,11 @@ def test_peercreds_unix_sock(http_request_timeout, peercreds_enabled_server): bind_addr = bind_addr.decode() # pylint: disable=possibly-unused-variable - quoted = urllib.parse.quote(bind_addr, safe='') - unix_base_uri = 'http+unix://{quoted}'.format(**locals()) + quoted = urllib.parse.quote(bind_addr, safe="") + unix_base_uri = "http+unix://{quoted}".format(**locals()) expected_peercreds = os.getpid(), os.getuid(), os.getgid() - expected_peercreds = '|'.join(map(str, expected_peercreds)) + expected_peercreds = "|".join(map(str, expected_peercreds)) with requests_unixsocket.monkeypatch(): peercreds_resp = requests.get( @@ -289,14 +292,13 @@ def test_peercreds_unix_sock(http_request_timeout, peercreds_enabled_server): @pytest.mark.skipif( not IS_UID_GID_RESOLVABLE, - reason='Modules `grp` and `pwd` are not available ' - 'under the current platform', + reason="Modules `grp` and `pwd` are not available " "under the current platform", ) @unix_only_sock_test @non_macos_sock_test def test_peercreds_unix_sock_with_lookup( - http_request_timeout, - peercreds_enabled_server, + http_request_timeout, + peercreds_enabled_server, ): """Check that ``PEERCRED`` resolution works when enabled.""" httpserver = peercreds_enabled_server @@ -308,16 +310,17 @@ def test_peercreds_unix_sock_with_lookup( bind_addr = bind_addr.decode() # pylint: disable=possibly-unused-variable - quoted = urllib.parse.quote(bind_addr, safe='') - unix_base_uri = 'http+unix://{quoted}'.format(**locals()) + quoted = urllib.parse.quote(bind_addr, safe="") + unix_base_uri = "http+unix://{quoted}".format(**locals()) import grp import pwd + expected_textcreds = ( pwd.getpwuid(os.getuid()).pw_name, grp.getgrgid(os.getgid()).gr_name, ) - expected_textcreds = '!'.join(map(str, expected_textcreds)) + expected_textcreds = "!".join(map(str, expected_textcreds)) with requests_unixsocket.monkeypatch(): peercreds_text_resp = requests.get( unix_base_uri + PEERCRED_TEXTS_URI, @@ -329,18 +332,18 @@ def test_peercreds_unix_sock_with_lookup( @pytest.mark.skipif( IS_WINDOWS, - reason='This regression test is for a Linux bug, ' - 'and the resource module is not available on Windows', + reason="This regression test is for a Linux bug, " + "and the resource module is not available on Windows", ) @pytest.mark.parametrize( - 'resource_limit', + "resource_limit", ( 1024, 2048, ), - indirect=('resource_limit',), + indirect=("resource_limit",), ) -@pytest.mark.usefixtures('many_open_sockets') +@pytest.mark.usefixtures("many_open_sockets") def test_high_number_of_file_descriptors(native_server_client, resource_limit): """Test the server does not crash with a high file-descriptor value. @@ -359,11 +362,12 @@ def test_high_number_of_file_descriptors(native_server_client, resource_limit): def native_process_conn(conn): native_process_conn.filenos.add(conn.socket.fileno()) return _old_process_conn(conn) + native_process_conn.filenos = set() native_server_client.server_instance.process_conn = native_process_conn # Trigger a crash if select() is used in the implementation - native_server_client.connect('/') + native_server_client.connect("/") # Ensure that at least one connection got accepted, otherwise the # follow-up check wouldn't make sense @@ -374,11 +378,11 @@ def native_process_conn(conn): @pytest.mark.skipif( - not hasattr(socket, 'SO_REUSEPORT'), - reason='socket.SO_REUSEPORT is not supported on this platform', + not hasattr(socket, "SO_REUSEPORT"), + reason="socket.SO_REUSEPORT is not supported on this platform", ) @pytest.mark.parametrize( - 'ip_addr', + "ip_addr", ( ANY_INTERFACE_IPV4, ANY_INTERFACE_IPV6, @@ -391,9 +395,11 @@ def test_reuse_port(http_server, ip_addr, mocker): s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) s.bind((ip_addr, EPHEMERAL_PORT)) server = HTTPServer( - bind_addr=s.getsockname()[:2], gateway=Gateway, reuse_port=True, + bind_addr=s.getsockname()[:2], + gateway=Gateway, + reuse_port=True, ) - spy = mocker.spy(server, 'prepare') + spy = mocker.spy(server, "prepare") server.prepare() server.stop() s.close() @@ -411,7 +417,7 @@ def _garbage_bin(): def resource_limit(request): """Set the resource limit two times bigger then requested.""" resource = pytest.importorskip( - 'resource', + "resource", reason='The "resource" module is Unix-specific', ) @@ -439,7 +445,7 @@ def many_open_sockets(request, resource_limit): # NOTE: `@pytest.mark.usefixtures` doesn't work on fixtures which # NOTE: forces us to invoke this one dynamically to avoid having an # NOTE: unused argument. - request.getfixturevalue('_garbage_bin') + request.getfixturevalue("_garbage_bin") # Hoard a lot of file descriptors by opening and storing a lot of sockets test_sockets = [] @@ -463,23 +469,23 @@ def many_open_sockets(request, resource_limit): @pytest.mark.parametrize( - ('minthreads', 'maxthreads', 'inited_maxthreads'), + ("minthreads", "maxthreads", "inited_maxthreads"), ( ( # NOTE: The docstring only mentions -1 to mean "no max", but other # NOTE: negative numbers should also work. 1, -2, - float('inf'), + float("inf"), ), - (1, -1, float('inf')), + (1, -1, float("inf")), (1, 1, 1), (1, 2, 2), - (1, float('inf'), float('inf')), - (2, -2, float('inf')), - (2, -1, float('inf')), + (1, float("inf"), float("inf")), + (2, -2, float("inf")), + (2, -1, float("inf")), (2, 2, 2), - (2, float('inf'), float('inf')), + (2, float("inf"), float("inf")), ), ) def test_threadpool_threadrange_set(minthreads, maxthreads, inited_maxthreads): @@ -498,21 +504,37 @@ def test_threadpool_threadrange_set(minthreads, maxthreads, inited_maxthreads): @pytest.mark.parametrize( - ('minthreads', 'maxthreads', 'error'), + ("minthreads", "maxthreads", "error"), ( - (-1, -1, 'min=-1 must be > 0'), - (-1, 0, 'min=-1 must be > 0'), - (-1, 1, 'min=-1 must be > 0'), - (-1, 2, 'min=-1 must be > 0'), - (0, -1, 'min=0 must be > 0'), - (0, 0, 'min=0 must be > 0'), - (0, 1, 'min=0 must be > 0'), - (0, 2, 'min=0 must be > 0'), - (1, 0, 'Expected an integer or the infinity value for the `max` argument but got 0.'), - (1, 0.5, 'Expected an integer or the infinity value for the `max` argument but got 0.5.'), - (2, 0, 'Expected an integer or the infinity value for the `max` argument but got 0.'), - (2, '1', "Expected an integer or the infinity value for the `max` argument but got '1'."), - (2, 1, 'max=1 must be > min=2'), + (-1, -1, "min=-1 must be > 0"), + (-1, 0, "min=-1 must be > 0"), + (-1, 1, "min=-1 must be > 0"), + (-1, 2, "min=-1 must be > 0"), + (0, -1, "min=0 must be > 0"), + (0, 0, "min=0 must be > 0"), + (0, 1, "min=0 must be > 0"), + (0, 2, "min=0 must be > 0"), + ( + 1, + 0, + "Expected an integer or the infinity value for the `max` argument but got 0.", + ), + ( + 1, + 0.5, + "Expected an integer or the infinity value for the `max` argument but got 0.5.", + ), + ( + 2, + 0, + "Expected an integer or the infinity value for the `max` argument but got 0.", + ), + ( + 2, + "1", + "Expected an integer or the infinity value for the `max` argument but got '1'.", + ), + (2, 1, "max=1 must be > min=2"), ), ) def test_threadpool_invalid_threadrange(minthreads, maxthreads, error): @@ -539,11 +561,11 @@ def test_threadpool_multistart_validation(monkeypatch): # actually starting any threads monkeypatch.setattr( ThreadPool, - '_spawn_worker', + "_spawn_worker", lambda _: types.SimpleNamespace(ready=True), ) tp = ThreadPool(server=None) tp.start() - with pytest.raises(RuntimeError, match='Threadpools can only be started once.'): + with pytest.raises(RuntimeError, match="Threadpools can only be started once."): tp.start() diff --git a/cheroot/test/test_ssl.py b/cheroot/test/test_ssl.py index 1900e20d15..1739d331aa 100644 --- a/cheroot/test/test_ssl.py +++ b/cheroot/test/test_ssl.py @@ -31,36 +31,36 @@ from ..wsgi import Gateway_10 -IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv('GITHUB_WORKFLOW')) +IS_GITHUB_ACTIONS_WORKFLOW = bool(os.getenv("GITHUB_WORKFLOW")) IS_WIN2016 = ( IS_WINDOWS # pylint: disable=unsupported-membership-test - and b'Microsoft Windows Server 2016 Datacenter' in subprocess.check_output( - ('systeminfo',), + and b"Microsoft Windows Server 2016 Datacenter" + in subprocess.check_output( + ("systeminfo",), ) ) -IS_LIBRESSL_BACKEND = ssl.OPENSSL_VERSION.startswith('LibreSSL') -IS_PYOPENSSL_SSL_VERSION_1_0 = ( - OpenSSL.SSL.SSLeay_version(OpenSSL.SSL.SSLEAY_VERSION). - startswith(b'OpenSSL 1.0.') -) +IS_LIBRESSL_BACKEND = ssl.OPENSSL_VERSION.startswith("LibreSSL") +IS_PYOPENSSL_SSL_VERSION_1_0 = OpenSSL.SSL.SSLeay_version( + OpenSSL.SSL.SSLEAY_VERSION +).startswith(b"OpenSSL 1.0.") PY310_PLUS = sys.version_info[:2] >= (3, 10) _stdlib_to_openssl_verify = { ssl.CERT_NONE: OpenSSL.SSL.VERIFY_NONE, ssl.CERT_OPTIONAL: OpenSSL.SSL.VERIFY_PEER, - ssl.CERT_REQUIRED: - OpenSSL.SSL.VERIFY_PEER + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, + ssl.CERT_REQUIRED: OpenSSL.SSL.VERIFY_PEER + + OpenSSL.SSL.VERIFY_FAIL_IF_NO_PEER_CERT, } missing_ipv6 = pytest.mark.skipif( - not _probe_ipv6_sock('::1'), - reason='' - 'IPv6 is disabled ' - '(for example, under Travis CI ' - 'which runs under GCE supporting only IPv4)', + not _probe_ipv6_sock("::1"), + reason="" + "IPv6 is disabled " + "(for example, under Travis CI " + "which runs under GCE supporting only IPv4)", ) @@ -71,20 +71,20 @@ def respond(self): """Respond with dummy content via HTTP.""" req = self.req req_uri = bton(req.uri) - if req_uri == '/': - req.status = b'200 OK' + if req_uri == "/": + req.status = b"200 OK" req.ensure_headers_sent() - req.write(b'Hello world!') + req.write(b"Hello world!") return - if req_uri == '/env': - req.status = b'200 OK' + if req_uri == "/env": + req.status = b"200 OK" req.ensure_headers_sent() env = self.get_environ() # drop files so that it can be json dumped - env.pop('wsgi.errors') - env.pop('wsgi.input') + env.pop("wsgi.errors") + env.pop("wsgi.input") print(env) - req.write(json.dumps(env).encode('utf-8')) + req.write(json.dumps(env).encode("utf-8")) return return super(HelloWorldGateway, self).respond() @@ -153,15 +153,19 @@ def _thread_except_hook(exceptions, args): if issubclass(args.exc_type, SystemExit): return # cannot store the exception, it references the thread's stack - exceptions.append(( - args.exc_type, - str(args.exc_value), - ''.join( - traceback.format_exception( - args.exc_type, args.exc_value, args.exc_traceback, + exceptions.append( + ( + args.exc_type, + str(args.exc_value), + "".join( + traceback.format_exception( + args.exc_type, + args.exc_value, + args.exc_traceback, + ), ), - ), - )) + ) + ) @pytest.fixture @@ -173,10 +177,11 @@ def thread_exceptions(): """ exceptions = [] # Python 3.8+ - orig_hook = getattr(threading, 'excepthook', None) + orig_hook = getattr(threading, "excepthook", None) if orig_hook is not None: threading.excepthook = functools.partial( - _thread_except_hook, exceptions, + _thread_except_hook, + exceptions, ) try: yield exceptions @@ -186,15 +191,16 @@ def thread_exceptions(): @pytest.mark.parametrize( - 'adapter_type', + "adapter_type", ( - 'builtin', - 'pyopenssl', + "builtin", + "pyopenssl", ), ) def test_ssl_adapters( http_request_timeout, - tls_http_server, adapter_type, + tls_http_server, + adapter_type, tls_certificate, tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path, @@ -204,9 +210,10 @@ def test_ssl_adapters( interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4) tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) tls_adapter = tls_adapter_cls( - tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path, + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, ) - if adapter_type == 'pyopenssl': + if adapter_type == "pyopenssl": tls_adapter.context = tls_adapter.get_context() tls_certificate.configure_cert(tls_adapter.context) @@ -221,32 +228,34 @@ def test_ssl_adapters( ) resp = requests.get( - 'https://{host!s}:{port!s}/'.format(host=interface, port=port), + "https://{host!s}:{port!s}/".format(host=interface, port=port), timeout=http_request_timeout, verify=tls_ca_certificate_pem_path, ) assert resp.status_code == 200 - assert resp.text == 'Hello world!' + assert resp.text == "Hello world!" @pytest.mark.parametrize( # noqa: C901 # FIXME - 'adapter_type', + "adapter_type", ( - 'builtin', - 'pyopenssl', + "builtin", + "pyopenssl", ), ) @pytest.mark.parametrize( - ('is_trusted_cert', 'tls_client_identity'), + ("is_trusted_cert", "tls_client_identity"), ( - (True, 'localhost'), (True, '127.0.0.1'), - (True, '*.localhost'), (True, 'not_localhost'), - (False, 'localhost'), + (True, "localhost"), + (True, "127.0.0.1"), + (True, "*.localhost"), + (True, "not_localhost"), + (False, "localhost"), ), ) @pytest.mark.parametrize( - 'tls_verify_mode', + "tls_verify_mode", ( ssl.CERT_NONE, # server shouldn't validate client cert ssl.CERT_OPTIONAL, # same as CERT_REQUIRED in client mode, don't use @@ -255,32 +264,31 @@ def test_ssl_adapters( ) @pytest.mark.xfail( IS_PYPY and IS_CI, - reason='Fails under PyPy in CI for unknown reason', + reason="Fails under PyPy in CI for unknown reason", strict=False, ) def test_tls_client_auth( # noqa: C901, WPS213 # FIXME # FIXME: remove twisted logic, separate tests http_request_timeout, mocker, - tls_http_server, adapter_type, + tls_http_server, + adapter_type, ca, tls_certificate, tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path, tls_ca_certificate_pem_path, - is_trusted_cert, tls_client_identity, + is_trusted_cert, + tls_client_identity, tls_verify_mode, ): """Verify that client TLS certificate auth works correctly.""" - test_cert_rejection = ( - tls_verify_mode != ssl.CERT_NONE - and not is_trusted_cert - ) + test_cert_rejection = tls_verify_mode != ssl.CERT_NONE and not is_trusted_cert interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4) client_cert_root_ca = ca if is_trusted_cert else trustme.CA() with mocker.mock_module.patch( - 'idna.core.ulabel', + "idna.core.ulabel", return_value=ntob(tls_client_identity), ): client_cert = client_cert_root_ca.issue_cert( @@ -294,7 +302,7 @@ def test_tls_client_auth( # noqa: C901, WPS213 # FIXME tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path, ) - if adapter_type == 'pyopenssl': + if adapter_type == "pyopenssl": tls_adapter.context = tls_adapter.get_context() tls_adapter.context.set_verify( _stdlib_to_openssl_verify[tls_verify_mode], @@ -312,14 +320,11 @@ def test_tls_client_auth( # noqa: C901, WPS213 # FIXME make_https_request = functools.partial( requests.get, - 'https://{host!s}:{port!s}/'.format(host=interface, port=port), - + "https://{host!s}:{port!s}/".format(host=interface, port=port), # Don't wait for the first byte forever: timeout=http_request_timeout, - # Server TLS certificate verification: verify=tls_ca_certificate_pem_path, - # Client TLS certificate verification: cert=cl_pem, ) @@ -328,34 +333,32 @@ def test_tls_client_auth( # noqa: C901, WPS213 # FIXME resp = make_https_request() is_req_successful = resp.status_code == 200 if ( - not is_req_successful - and IS_PYOPENSSL_SSL_VERSION_1_0 - and adapter_type == 'builtin' - and tls_verify_mode == ssl.CERT_REQUIRED - and tls_client_identity == 'localhost' - and is_trusted_cert + not is_req_successful + and IS_PYOPENSSL_SSL_VERSION_1_0 + and adapter_type == "builtin" + and tls_verify_mode == ssl.CERT_REQUIRED + and tls_client_identity == "localhost" + and is_trusted_cert ): pytest.xfail( - 'OpenSSL 1.0 has problems with verifying client certs', + "OpenSSL 1.0 has problems with verifying client certs", ) assert is_req_successful - assert resp.text == 'Hello world!' + assert resp.text == "Hello world!" resp.close() return # xfail some flaky tests # https://github.com/cherrypy/cheroot/issues/237 issue_237 = ( - IS_MACOS - and adapter_type == 'builtin' - and tls_verify_mode != ssl.CERT_NONE + IS_MACOS and adapter_type == "builtin" and tls_verify_mode != ssl.CERT_NONE ) if issue_237: - pytest.xfail('Test sometimes fails') + pytest.xfail("Test sometimes fails") - expected_ssl_errors = requests.exceptions.SSLError, + expected_ssl_errors = (requests.exceptions.SSLError,) if IS_WINDOWS or IS_GITHUB_ACTIONS_WORKFLOW: - expected_ssl_errors += requests.exceptions.ConnectionError, + expected_ssl_errors += (requests.exceptions.ConnectionError,) with pytest.raises(expected_ssl_errors) as ssl_err: make_https_request().close() @@ -371,65 +374,61 @@ def test_tls_client_auth( # noqa: C901, WPS213 # FIXME err_text = str(ssl_err.value) expected_substrings = ( - 'sslv3 alert bad certificate' if IS_LIBRESSL_BACKEND - else 'tlsv1 alert unknown ca', + "sslv3 alert bad certificate" + if IS_LIBRESSL_BACKEND + else "tlsv1 alert unknown ca", ) - if IS_MACOS and IS_PYPY and adapter_type == 'pyopenssl': - expected_substrings = ('tlsv1 alert unknown ca',) + if IS_MACOS and IS_PYPY and adapter_type == "pyopenssl": + expected_substrings = ("tlsv1 alert unknown ca",) if ( - tls_verify_mode in ( - ssl.CERT_REQUIRED, - ssl.CERT_OPTIONAL, - ) - and not is_trusted_cert - and tls_client_identity == 'localhost' + tls_verify_mode + in ( + ssl.CERT_REQUIRED, + ssl.CERT_OPTIONAL, + ) + and not is_trusted_cert + and tls_client_identity == "localhost" ): expected_substrings += ( - 'bad handshake: ' - "SysCallError(10054, 'WSAECONNRESET')", - "('Connection aborted.', " - 'OSError("(10054, \'WSAECONNRESET\')"))', - "('Connection aborted.', " - 'OSError("(10054, \'WSAECONNRESET\')",))', - "('Connection aborted.', " - 'error("(10054, \'WSAECONNRESET\')",))', - "('Connection aborted.', " - 'ConnectionResetError(10054, ' - "'An existing connection was forcibly closed " - "by the remote host', None, 10054, None))", - "('Connection aborted.', " - 'error(10054, ' - "'An existing connection was forcibly closed " - "by the remote host'))", - ) if IS_WINDOWS else ( - "('Connection aborted.', " - 'OSError("(104, \'ECONNRESET\')"))', - "('Connection aborted.', " - 'OSError("(104, \'ECONNRESET\')",))', - "('Connection aborted.', " - 'error("(104, \'ECONNRESET\')",))', - "('Connection aborted.', " - "ConnectionResetError(104, 'Connection reset by peer'))", - "('Connection aborted.', " - "error(104, 'Connection reset by peer'))", - ) if ( - IS_GITHUB_ACTIONS_WORKFLOW - and IS_LINUX - ) else ( - "('Connection aborted.', " - "BrokenPipeError(32, 'Broken pipe'))", + ( + "bad handshake: " "SysCallError(10054, 'WSAECONNRESET')", + "('Connection aborted.', " "OSError(\"(10054, 'WSAECONNRESET')\"))", + "('Connection aborted.', " + "OSError(\"(10054, 'WSAECONNRESET')\",))", + "('Connection aborted.', " "error(\"(10054, 'WSAECONNRESET')\",))", + "('Connection aborted.', " + "ConnectionResetError(10054, " + "'An existing connection was forcibly closed " + "by the remote host', None, 10054, None))", + "('Connection aborted.', " + "error(10054, " + "'An existing connection was forcibly closed " + "by the remote host'))", + ) + if IS_WINDOWS + else ( + "('Connection aborted.', " "OSError(\"(104, 'ECONNRESET')\"))", + "('Connection aborted.', " "OSError(\"(104, 'ECONNRESET')\",))", + "('Connection aborted.', " "error(\"(104, 'ECONNRESET')\",))", + "('Connection aborted.', " + "ConnectionResetError(104, 'Connection reset by peer'))", + "('Connection aborted.', " + "error(104, 'Connection reset by peer'))", + ) + if (IS_GITHUB_ACTIONS_WORKFLOW and IS_LINUX) + else ("('Connection aborted.', " "BrokenPipeError(32, 'Broken pipe'))",) ) if PY310_PLUS: # FIXME: Figure out what's happening and correct the problem expected_substrings += ( - 'SSLError(SSLEOFError(8, ' + "SSLError(SSLEOFError(8, " "'EOF occurred in violation of protocol (_ssl.c:", ) if IS_GITHUB_ACTIONS_WORKFLOW and IS_WINDOWS and PY310_PLUS: expected_substrings += ( "('Connection aborted.', " - 'RemoteDisconnected(' + "RemoteDisconnected(" "'Remote end closed connection without response'))", ) @@ -437,22 +436,22 @@ def test_tls_client_auth( # noqa: C901, WPS213 # FIXME @pytest.mark.parametrize( # noqa: C901 # FIXME - 'adapter_type', + "adapter_type", ( pytest.param( - 'builtin', + "builtin", marks=pytest.mark.xfail( IS_MACOS and PY310_PLUS, - reason='Unclosed TLS resource warnings happen on macOS ' - 'under Python 3.10 (#508)', + reason="Unclosed TLS resource warnings happen on macOS " + "under Python 3.10 (#508)", strict=False, ), ), - 'pyopenssl', + "pyopenssl", ), ) @pytest.mark.parametrize( - ('tls_verify_mode', 'use_client_cert'), + ("tls_verify_mode", "use_client_cert"), ( (ssl.CERT_NONE, False), (ssl.CERT_NONE, True), @@ -462,25 +461,28 @@ def test_tls_client_auth( # noqa: C901, WPS213 # FIXME ), ) def test_ssl_env( # noqa: C901 # FIXME - thread_exceptions, - recwarn, - mocker, - http_request_timeout, - tls_http_server, adapter_type, - ca, tls_verify_mode, tls_certificate, - tls_certificate_chain_pem_path, - tls_certificate_private_key_pem_path, - tls_ca_certificate_pem_path, - use_client_cert, + thread_exceptions, + recwarn, + mocker, + http_request_timeout, + tls_http_server, + adapter_type, + ca, + tls_verify_mode, + tls_certificate, + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, + tls_ca_certificate_pem_path, + use_client_cert, ): """Test the SSL environment generated by the SSL adapters.""" interface, _host, port = _get_conn_data(ANY_INTERFACE_IPV4) with mocker.mock_module.patch( - 'idna.core.ulabel', - return_value=ntob('127.0.0.1'), + "idna.core.ulabel", + return_value=ntob("127.0.0.1"), ): - client_cert = ca.issue_cert(ntou('127.0.0.1')) + client_cert = ca.issue_cert(ntou("127.0.0.1")) with client_cert.private_key_and_cert_chain_pem.tempfile() as cl_pem: tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) @@ -488,7 +490,7 @@ def test_ssl_env( # noqa: C901 # FIXME tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path, ) - if adapter_type == 'pyopenssl': + if adapter_type == "pyopenssl": tls_adapter.context = tls_adapter.get_context() tls_adapter.context.set_verify( _stdlib_to_openssl_verify[tls_verify_mode], @@ -505,41 +507,43 @@ def test_ssl_env( # noqa: C901 # FIXME interface, _host, port = _get_conn_data(tlswsgiserver.bind_addr) resp = requests.get( - 'https://' + interface + ':' + str(port) + '/env', + "https://" + interface + ":" + str(port) + "/env", timeout=http_request_timeout, verify=tls_ca_certificate_pem_path, cert=cl_pem if use_client_cert else None, ) - env = json.loads(resp.content.decode('utf-8')) + env = json.loads(resp.content.decode("utf-8")) # hard coded env - assert env['wsgi.url_scheme'] == 'https' - assert env['HTTPS'] == 'on' + assert env["wsgi.url_scheme"] == "https" + assert env["HTTPS"] == "on" # ensure these are present - for key in {'SSL_VERSION_INTERFACE', 'SSL_VERSION_LIBRARY'}: + for key in {"SSL_VERSION_INTERFACE", "SSL_VERSION_LIBRARY"}: assert key in env # pyOpenSSL generates the env before the handshake completes - if adapter_type == 'pyopenssl': + if adapter_type == "pyopenssl": return - for key in {'SSL_PROTOCOL', 'SSL_CIPHER'}: + for key in {"SSL_PROTOCOL", "SSL_CIPHER"}: assert key in env # client certificate env if tls_verify_mode == ssl.CERT_NONE or not use_client_cert: - assert env['SSL_CLIENT_VERIFY'] == 'NONE' + assert env["SSL_CLIENT_VERIFY"] == "NONE" else: - assert env['SSL_CLIENT_VERIFY'] == 'SUCCESS' + assert env["SSL_CLIENT_VERIFY"] == "SUCCESS" - with open(cl_pem, 'rt') as f: - assert env['SSL_CLIENT_CERT'] in f.read() + with open(cl_pem, "rt") as f: + assert env["SSL_CLIENT_CERT"] in f.read() for key in { - 'SSL_CLIENT_M_VERSION', 'SSL_CLIENT_M_SERIAL', - 'SSL_CLIENT_I_DN', 'SSL_CLIENT_S_DN', + "SSL_CLIENT_M_VERSION", + "SSL_CLIENT_M_SERIAL", + "SSL_CLIENT_I_DN", + "SSL_CLIENT_S_DN", }: assert key in env @@ -558,13 +562,15 @@ def test_ssl_env( # noqa: C901 # FIXME # all of these sporadic warnings appear to be about socket.socket # and have been observed to come from requests connection pool msg = str(warn.message) - if 'socket.socket' in msg: + if "socket.socket" in msg: pytest.xfail( - '\n'.join(( - 'Sometimes this test fails due to ' - 'a socket.socket ResourceWarning:', - msg, - )), + "\n".join( + ( + "Sometimes this test fails due to " + "a socket.socket ResourceWarning:", + msg, + ) + ), ) pytest.fail(msg) @@ -572,14 +578,16 @@ def test_ssl_env( # noqa: C901 # FIXME # the builtin ssl environment generation uses a thread for _, _, trace in thread_exceptions: print(trace, file=sys.stderr) - assert not thread_exceptions, ': '.join(( - thread_exceptions[0][0].__name__, - thread_exceptions[0][1], - )) + assert not thread_exceptions, ": ".join( + ( + thread_exceptions[0][0].__name__, + thread_exceptions[0][1], + ) + ) @pytest.mark.parametrize( - 'ip_addr', + "ip_addr", ( ANY_INTERFACE_IPV4, ANY_INTERFACE_IPV6, @@ -591,27 +599,26 @@ def test_https_over_http_error(http_server, ip_addr): interface, _host, port = _get_conn_data(httpserver.bind_addr) with pytest.raises(ssl.SSLError) as ssl_err: http.client.HTTPSConnection( - '{interface}:{port}'.format( + "{interface}:{port}".format( interface=interface, port=port, ), - ).request('GET', '/') + ).request("GET", "/") expected_substring = ( - 'wrong version number' if IS_ABOVE_OPENSSL10 - else 'unknown protocol' + "wrong version number" if IS_ABOVE_OPENSSL10 else "unknown protocol" ) assert expected_substring in ssl_err.value.args[-1] @pytest.mark.parametrize( - 'adapter_type', + "adapter_type", ( - 'builtin', - 'pyopenssl', + "builtin", + "pyopenssl", ), ) @pytest.mark.parametrize( - 'ip_addr', + "ip_addr", ( ANY_INTERFACE_IPV4, pytest.param(ANY_INTERFACE_IPV6, marks=missing_ipv6), @@ -620,8 +627,10 @@ def test_https_over_http_error(http_server, ip_addr): @pytest.mark.flaky(reruns=3, reruns_delay=2) def test_http_over_https_error( http_request_timeout, - tls_http_server, adapter_type, - ca, ip_addr, + tls_http_server, + adapter_type, + ca, + ip_addr, tls_certificate, tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path, @@ -629,18 +638,16 @@ def test_http_over_https_error( """Ensure that connecting over HTTP to HTTPS port is handled.""" # disable some flaky tests # https://github.com/cherrypy/cheroot/issues/225 - issue_225 = ( - IS_MACOS - and adapter_type == 'builtin' - ) + issue_225 = IS_MACOS and adapter_type == "builtin" if issue_225: - pytest.xfail('Test fails in Travis-CI') + pytest.xfail("Test fails in Travis-CI") tls_adapter_cls = get_ssl_adapter_class(name=adapter_type) tls_adapter = tls_adapter_cls( - tls_certificate_chain_pem_path, tls_certificate_private_key_pem_path, + tls_certificate_chain_pem_path, + tls_certificate_private_key_pem_path, ) - if adapter_type == 'pyopenssl': + if adapter_type == "pyopenssl": tls_adapter.context = tls_adapter.get_context() tls_certificate.configure_cert(tls_adapter.context) @@ -654,49 +661,46 @@ def test_http_over_https_error( fqdn = interface if ip_addr is ANY_INTERFACE_IPV6: - fqdn = '[{fqdn}]'.format(**locals()) + fqdn = "[{fqdn}]".format(**locals()) - expect_fallback_response_over_plain_http = ( - ( - adapter_type == 'pyopenssl' - ) - ) + expect_fallback_response_over_plain_http = adapter_type == "pyopenssl" if expect_fallback_response_over_plain_http: resp = requests.get( - 'http://{host!s}:{port!s}/'.format(host=fqdn, port=port), + "http://{host!s}:{port!s}/".format(host=fqdn, port=port), timeout=http_request_timeout, ) assert resp.status_code == 400 assert resp.text == ( - 'The client sent a plain HTTP request, ' - 'but this server only speaks HTTPS on this port.' + "The client sent a plain HTTP request, " + "but this server only speaks HTTPS on this port." ) return with pytest.raises(requests.exceptions.ConnectionError) as ssl_err: requests.get( # FIXME: make stdlib ssl behave like PyOpenSSL - 'http://{host!s}:{port!s}/'.format(host=fqdn, port=port), + "http://{host!s}:{port!s}/".format(host=fqdn, port=port), timeout=http_request_timeout, ) if IS_LINUX: expected_error_code, expected_error_text = ( - 104, 'Connection reset by peer', + 104, + "Connection reset by peer", ) if IS_MACOS: expected_error_code, expected_error_text = ( - 54, 'Connection reset by peer', + 54, + "Connection reset by peer", ) if IS_WINDOWS: expected_error_code, expected_error_text = ( 10054, - 'An existing connection was forcibly closed by the remote host', + "An existing connection was forcibly closed by the remote host", ) underlying_error = ssl_err.value.args[0].args[-1] err_text = str(underlying_error) - assert underlying_error.errno == expected_error_code, ( - 'The underlying error is {underlying_error!r}'. - format(**locals()) - ) + assert ( + underlying_error.errno == expected_error_code + ), "The underlying error is {underlying_error!r}".format(**locals()) assert expected_error_text in err_text diff --git a/cheroot/test/test_wsgi.py b/cheroot/test/test_wsgi.py index 14005a84ee..ba0e9d636c 100644 --- a/cheroot/test/test_wsgi.py +++ b/cheroot/test/test_wsgi.py @@ -22,16 +22,16 @@ def simple_wsgi_server(): port = portend.find_available_local_port() def app(_environ, start_response): - status = '200 OK' - response_headers = [('Content-type', 'text/plain')] + status = "200 OK" + response_headers = [("Content-type", "text/plain")] start_response(status, response_headers) - return [b'Hello world!'] + return [b"Hello world!"] - host = '::' + host = "::" addr = host, port server = wsgi.Server(addr, app, timeout=600 if IS_SLOW_ENV else 20) # pylint: disable=possibly-unused-variable - url = 'http://localhost:{port}/'.format(**locals()) + url = "http://localhost:{port}/".format(**locals()) # pylint: disable=possibly-unused-variable with server._run_in_thread() as thread: yield locals() @@ -40,24 +40,22 @@ def app(_environ, start_response): @pytest.mark.flaky(reruns=3, reruns_delay=2) def test_connection_keepalive(simple_wsgi_server): """Test the connection keepalive works (duh).""" - session = Session(base_url=simple_wsgi_server['url']) + session = Session(base_url=simple_wsgi_server["url"]) pooled = requests.adapters.HTTPAdapter( - pool_connections=1, pool_maxsize=1000, + pool_connections=1, + pool_maxsize=1000, ) - session.mount('http://', pooled) + session.mount("http://", pooled) def do_request(): with ExceptionTrap(requests.exceptions.ConnectionError) as trap: - resp = session.get('info') + resp = session.get("info") resp.raise_for_status() print_tb(trap.tb) return bool(trap) with ThreadPoolExecutor(max_workers=10 if IS_SLOW_ENV else 50) as pool: - tasks = [ - pool.submit(do_request) - for n in range(250 if IS_SLOW_ENV else 1000) - ] + tasks = [pool.submit(do_request) for n in range(250 if IS_SLOW_ENV else 1000)] failures = sum(task.result() for task in tasks) session.close() @@ -66,20 +64,20 @@ def do_request(): def test_gateway_start_response_called_twice(monkeypatch): """Verify that repeat calls of ``Gateway.start_response()`` fail.""" - monkeypatch.setattr(wsgi.Gateway, 'get_environ', lambda self: {}) + monkeypatch.setattr(wsgi.Gateway, "get_environ", lambda self: {}) wsgi_gateway = wsgi.Gateway(None) wsgi_gateway.started_response = True - err_msg = '^WSGI start_response called a second time with no exc_info.$' + err_msg = "^WSGI start_response called a second time with no exc_info.$" with pytest.raises(RuntimeError, match=err_msg): - wsgi_gateway.start_response('200', (), None) + wsgi_gateway.start_response("200", (), None) def test_gateway_write_needs_start_response_called_before(monkeypatch): """Check that calling ``Gateway.write()`` needs started response.""" - monkeypatch.setattr(wsgi.Gateway, 'get_environ', lambda self: {}) + monkeypatch.setattr(wsgi.Gateway, "get_environ", lambda self: {}) wsgi_gateway = wsgi.Gateway(None) - err_msg = '^WSGI write called before start_response.$' + err_msg = "^WSGI write called before start_response.$" with pytest.raises(RuntimeError, match=err_msg): wsgi_gateway.write(None) # The actual arg value is unimportant diff --git a/cheroot/test/webtest.py b/cheroot/test/webtest.py index 7005a0a373..662670a5ae 100644 --- a/cheroot/test/webtest.py +++ b/cheroot/test/webtest.py @@ -39,18 +39,19 @@ def interface(host): If the server is listening on '0.0.0.0' (INADDR_ANY) or '::' (IN6ADDR_ANY), this will return the proper localhost. """ - if host == '0.0.0.0': + if host == "0.0.0.0": # INADDR_ANY, which should respond on localhost. - return '127.0.0.1' - if host == '::': + return "127.0.0.1" + if host == "::": # IN6ADDR_ANY, which should respond on localhost. - return '::1' + return "::1" return host try: # Jython support - if sys.platform[:4] == 'java': + if sys.platform[:4] == "java": + def getchar(): """Get a key press.""" # Hopefully this is enough @@ -85,8 +86,8 @@ class NonDataProperty: def __init__(self, fget): """Initialize a non-data property.""" - assert fget is not None, 'fget cannot be none' - assert callable(fget), 'fget must be callable' + assert fget is not None, "fget cannot be none" + assert callable(fget), "fget must be callable" self.fget = fget def __get__(self, obj, objtype=None): @@ -99,12 +100,12 @@ def __get__(self, obj, objtype=None): class WebCase(unittest.TestCase): """Helper web test suite base.""" - HOST = '127.0.0.1' + HOST = "127.0.0.1" PORT = 8000 HTTP_CONN = http.client.HTTPConnection - PROTOCOL = 'HTTP/1.1' + PROTOCOL = "HTTP/1.1" - scheme = 'http' + scheme = "http" url = None ssl_context = None @@ -112,7 +113,7 @@ class WebCase(unittest.TestCase): headers = None body = None - encoding = 'utf-8' + encoding = "utf-8" time = None @@ -122,7 +123,7 @@ def _Conn(self): * from :py:mod:`python:http.client`. """ - cls_name = '{scheme}Connection'.format(scheme=self.scheme.upper()) + cls_name = "{scheme}Connection".format(scheme=self.scheme.upper()) return getattr(http.client, cls_name) def get_conn(self, auto_open=False): @@ -147,16 +148,12 @@ def set_persistent(self, on=True, auto_open=False): except (TypeError, AttributeError): pass - self.HTTP_CONN = ( - self.get_conn(auto_open=auto_open) - if on - else self._Conn - ) + self.HTTP_CONN = self.get_conn(auto_open=auto_open) if on else self._Conn @property def persistent(self): """Presence of the persistent HTTP connection.""" - return hasattr(self.HTTP_CONN, '__class__') + return hasattr(self.HTTP_CONN, "__class__") @persistent.setter def persistent(self, on): @@ -171,8 +168,13 @@ def interface(self): return interface(self.HOST) def getPage( - self, url, headers=None, method='GET', body=None, - protocol=None, raise_subcls=(), + self, + url, + headers=None, + method="GET", + body=None, + protocol=None, + raise_subcls=(), ): """Open the url with debugging support. @@ -198,9 +200,9 @@ def getPage( ServerError.on = False if isinstance(url, str): - url = url.encode('utf-8') + url = url.encode("utf-8") if isinstance(body, str): - body = body.encode('utf-8') + body = body.encode("utf-8") # for compatibility, support raise_subcls is None raise_subcls = raise_subcls or () @@ -209,8 +211,14 @@ def getPage( self.time = None start = time.time() result = openURL( - url, headers, method, body, self.HOST, self.PORT, - self.HTTP_CONN, protocol or self.PROTOCOL, + url, + headers, + method, + body, + self.HOST, + self.PORT, + self.HTTP_CONN, + protocol or self.PROTOCOL, raise_subcls=raise_subcls, ssl_context=self.ssl_context, ) @@ -219,8 +227,7 @@ def getPage( # Build a list of request cookies from the previous response cookies. self.cookies = [ - ('Cookie', v) for k, v in self.headers - if k.lower() == 'set-cookie' + ("Cookie", v) for k, v in self.headers if k.lower() == "set-cookie" ] if ServerError.on: @@ -235,12 +242,12 @@ def interactive(self): the value can be numeric or a string like true or False or 1 or 0. """ - env_str = os.environ.get('WEBTEST_INTERACTIVE', 'True') + env_str = os.environ.get("WEBTEST_INTERACTIVE", "True") is_interactive = bool(json.loads(env_str.lower())) if is_interactive: warnings.warn( - 'Interactive test failure interceptor support via ' - 'WEBTEST_INTERACTIVE environment variable is deprecated.', + "Interactive test failure interceptor support via " + "WEBTEST_INTERACTIVE environment variable is deprecated.", DeprecationWarning, stacklevel=1, ) @@ -249,49 +256,49 @@ def interactive(self): console_height = 30 def _handlewebError(self, msg): # noqa: C901 # FIXME - print('') - print(' ERROR: %s' % msg) + print("") + print(" ERROR: %s" % msg) if not self.interactive: raise self.failureException(msg) p = ( - ' Show: ' - '[B]ody [H]eaders [S]tatus [U]RL; ' - '[I]gnore, [R]aise, or sys.e[X]it >> ' + " Show: " + "[B]ody [H]eaders [S]tatus [U]RL; " + "[I]gnore, [R]aise, or sys.e[X]it >> " ) sys.stdout.write(p) sys.stdout.flush() while True: i = getchar().upper() - if not isinstance(i, type('')): - i = i.decode('ascii') - if i not in 'BHSUIRX': + if not isinstance(i, type("")): + i = i.decode("ascii") + if i not in "BHSUIRX": continue print(i.upper()) # Also prints new line - if i == 'B': + if i == "B": for x, line in enumerate(self.body.splitlines()): if (x + 1) % self.console_height == 0: # The \r and comma should make the next line overwrite - sys.stdout.write('<-- More -->\r') + sys.stdout.write("<-- More -->\r") m = getchar().lower() # Erase our "More" prompt - sys.stdout.write(' \r') - if m == 'q': + sys.stdout.write(" \r") + if m == "q": break print(line) - elif i == 'H': + elif i == "H": pprint.pprint(self.headers) - elif i == 'S': + elif i == "S": print(self.status) - elif i == 'U': + elif i == "U": print(self.url) - elif i == 'I': + elif i == "I": # return without raising the normal exception return - elif i == 'R': + elif i == "R": raise self.failureException(msg) - elif i == 'X': + elif i == "X": sys.exit() sys.stdout.write(p) sys.stdout.flush() @@ -303,11 +310,7 @@ def status_code(self): # noqa: D401; irrelevant for properties def status_matches(self, expected): """Check whether actual status matches expected.""" - actual = ( - self.status_code - if isinstance(expected, int) else - self.status - ) + actual = self.status_code if isinstance(expected, int) else self.status return expected == actual def assertStatus(self, status, msg=None): @@ -319,7 +322,7 @@ def assertStatus(self, status, msg=None): if any(map(self.status_matches, always_iterable(status))): return - tmpl = 'Status {self.status} does not match {status}' + tmpl = "Status {self.status} does not match {status}" msg = msg or tmpl.format(**locals()) self._handlewebError(msg) @@ -333,9 +336,9 @@ def assertHeader(self, key, value=None, msg=None): if msg is None: if value is None: - msg = '%r not in headers' % key + msg = "%r not in headers" % key else: - msg = '%r:%r not in headers' % (key, value) + msg = "%r:%r not in headers" % (key, value) self._handlewebError(msg) def assertHeaderIn(self, key, values, msg=None): @@ -348,18 +351,18 @@ def assertHeaderIn(self, key, values, msg=None): return matches if msg is None: - msg = '%(key)r not in %(values)r' % vars() + msg = "%(key)r not in %(values)r" % vars() self._handlewebError(msg) def assertHeaderItemValue(self, key, value, msg=None): """Fail if the header does not contain the specified value.""" actual_value = self.assertHeader(key, msg=msg) - header_values = map(str.strip, actual_value.split(',')) + header_values = map(str.strip, actual_value.split(",")) if value in header_values: return value if msg is None: - msg = '%r not in %r' % (value, header_values) + msg = "%r not in %r" % (value, header_values) self._handlewebError(msg) def assertNoHeader(self, key, msg=None): @@ -368,7 +371,7 @@ def assertNoHeader(self, key, msg=None): matches = [k for k, v in self.headers if k.lower() == lowkey] if matches: if msg is None: - msg = '%r in headers' % key + msg = "%r in headers" % key self._handlewebError(msg) def assertNoHeaderItemValue(self, key, value, msg=None): @@ -378,7 +381,7 @@ def assertNoHeaderItemValue(self, key, value, msg=None): matches = [k for k, v in hdrs if k.lower() == lowkey and v == value] if matches: if msg is None: - msg = '%r:%r in %r' % (key, value, hdrs) + msg = "%r:%r in %r" % (key, value, hdrs) self._handlewebError(msg) def assertBody(self, value, msg=None): @@ -387,8 +390,9 @@ def assertBody(self, value, msg=None): value = value.encode(self.encoding) if value != self.body: if msg is None: - msg = 'expected body:\n%r\n\nactual body:\n%r' % ( - value, self.body, + msg = "expected body:\n%r\n\nactual body:\n%r" % ( + value, + self.body, ) self._handlewebError(msg) @@ -398,7 +402,7 @@ def assertInBody(self, value, msg=None): value = value.encode(self.encoding) if value not in self.body: if msg is None: - msg = '%r not in body: %s' % (value, self.body) + msg = "%r not in body: %s" % (value, self.body) self._handlewebError(msg) def assertNotInBody(self, value, msg=None): @@ -407,7 +411,7 @@ def assertNotInBody(self, value, msg=None): value = value.encode(self.encoding) if value in self.body: if msg is None: - msg = '%r found in body' % value + msg = "%r found in body" % value self._handlewebError(msg) def assertMatchesBody(self, pattern, msg=None, flags=0): @@ -416,11 +420,11 @@ def assertMatchesBody(self, pattern, msg=None, flags=0): pattern = pattern.encode(self.encoding) if re.search(pattern, self.body, flags) is None: if msg is None: - msg = 'No match for %r in body' % pattern + msg = "No match for %r in body" % pattern self._handlewebError(msg) -methods_with_bodies = ('POST', 'PUT', 'PATCH') +methods_with_bodies = ("POST", "PUT", "PATCH") def cleanHeaders(headers, method, body, host, port): @@ -432,34 +436,34 @@ def cleanHeaders(headers, method, body, host, port): # [This specifies the host:port of the server, not the client.] found = False for k, _v in headers: - if k.lower() == 'host': + if k.lower() == "host": found = True break if not found: if port == 80: - headers.append(('Host', host)) + headers.append(("Host", host)) else: - headers.append(('Host', '%s:%s' % (host, port))) + headers.append(("Host", "%s:%s" % (host, port))) if method in methods_with_bodies: # Stick in default type and length headers if not present found = False for k, v in headers: - if k.lower() == 'content-type': + if k.lower() == "content-type": found = True break if not found: headers.append( - ('Content-Type', 'application/x-www-form-urlencoded'), + ("Content-Type", "application/x-www-form-urlencoded"), ) - headers.append(('Content-Length', str(len(body or '')))) + headers.append(("Content-Length", str(len(body or "")))) return headers def shb(response): """Return status, headers, body the way we like from a response.""" - resp_status_line = '%s %s' % (response.status, response.reason) + resp_status_line = "%s %s" % (response.status, response.reason) return resp_status_line, response.getheaders(), response.read() @@ -489,38 +493,46 @@ def on_exception(): def _open_url_once( - url, headers=None, method='GET', body=None, - host='127.0.0.1', port=8000, http_conn=http.client.HTTPConnection, - protocol='HTTP/1.1', ssl_context=None, + url, + headers=None, + method="GET", + body=None, + host="127.0.0.1", + port=8000, + http_conn=http.client.HTTPConnection, + protocol="HTTP/1.1", + ssl_context=None, ): """Open the given HTTP resource and return status, headers, and body.""" headers = cleanHeaders(headers, method, body, host, port) # Allow http_conn to be a class or an instance - if hasattr(http_conn, 'host'): + if hasattr(http_conn, "host"): conn = http_conn else: kw = {} if ssl_context: - kw['context'] = ssl_context + kw["context"] = ssl_context conn = http_conn(interface(host), port, **kw) conn._http_vsn_str = protocol - conn._http_vsn = int(''.join([x for x in protocol if x.isdigit()])) + conn._http_vsn = int("".join([x for x in protocol if x.isdigit()])) if isinstance(url, bytes): url = url.decode() conn.putrequest( - method.upper(), url, skip_host=True, + method.upper(), + url, + skip_host=True, skip_accept_encoding=True, ) for key, value in headers: - conn.putheader(key, value.encode('Latin-1')) + conn.putheader(key, value.encode("Latin-1")) conn.endheaders() if body is not None: conn.send(body) # Handle response response = conn.getresponse() s, h, b = shb(response) - if not hasattr(http_conn, 'host'): + if not hasattr(http_conn, "host"): # We made our own conn instance. Close it. conn.close() return s, h, b @@ -550,7 +562,7 @@ def strip_netloc(url): """ parsed = urllib.parse.urlparse(url) _scheme, _netloc, path, params, query, _fragment = parsed - stripped = '', '', path, params, query, '' + stripped = "", "", path, params, query, "" return urllib.parse.urlunparse(stripped) @@ -584,6 +596,6 @@ def server_error(exc=None): return False else: ServerError.on = True - print('') - print(''.join(traceback.format_exception(*exc))) + print("") + print("".join(traceback.format_exception(*exc))) return True diff --git a/cheroot/testing.py b/cheroot/testing.py index 5457a4b1ce..211a91423b 100644 --- a/cheroot/testing.py +++ b/cheroot/testing.py @@ -15,17 +15,17 @@ EPHEMERAL_PORT = 0 NO_INTERFACE = None # Using this or '' will cause an exception -ANY_INTERFACE_IPV4 = '0.0.0.0' -ANY_INTERFACE_IPV6 = '::' +ANY_INTERFACE_IPV4 = "0.0.0.0" +ANY_INTERFACE_IPV6 = "::" config = { cheroot.wsgi.Server: { - 'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT), - 'wsgi_app': None, + "bind_addr": (NO_INTERFACE, EPHEMERAL_PORT), + "wsgi_app": None, }, cheroot.server.HTTPServer: { - 'bind_addr': (NO_INTERFACE, EPHEMERAL_PORT), - 'gateway': cheroot.server.Gateway, + "bind_addr": (NO_INTERFACE, EPHEMERAL_PORT), + "gateway": cheroot.server.Gateway, }, } @@ -34,7 +34,7 @@ def cheroot_server(server_factory): # noqa: WPS210 """Set up and tear down a Cheroot server instance.""" conf = config[server_factory].copy() - bind_port = conf.pop('bind_addr')[-1] + bind_port = conf.pop("bind_addr")[-1] for interface in ANY_INTERFACE_IPV6, ANY_INTERFACE_IPV4: try: @@ -107,25 +107,31 @@ def __init__(self, server): self._http_connection = self.get_connection() def get_connection(self): - name = '{interface}:{port}'.format( + name = "{interface}:{port}".format( interface=self._interface, port=self._port, ) conn_cls = ( http.client.HTTPConnection - if self.server_instance.ssl_adapter is None else - http.client.HTTPSConnection + if self.server_instance.ssl_adapter is None + else http.client.HTTPSConnection ) return conn_cls(name) def request( - self, uri, method='GET', headers=None, http_conn=None, - protocol='HTTP/1.1', + self, + uri, + method="GET", + headers=None, + http_conn=None, + protocol="HTTP/1.1", ): return webtest.openURL( - uri, method=method, + uri, + method=method, headers=headers, - host=self._host, port=self._port, + host=self._host, + port=self._port, http_conn=http_conn or self._http_connection, protocol=protocol, ) @@ -161,9 +167,9 @@ def _get_conn_data(bind_addr): interface = webtest.interface(host) - if ':' in interface and not _probe_ipv6_sock(interface): - interface = '127.0.0.1' - if ':' in host: + if ":" in interface and not _probe_ipv6_sock(interface): + interface = "127.0.0.1" + if ":" in host: host = interface return interface, host, port diff --git a/cheroot/testing.pyi b/cheroot/testing.pyi index 4c825f9867..3b3773706f 100644 --- a/cheroot/testing.pyi +++ b/cheroot/testing.pyi @@ -3,7 +3,7 @@ from typing import Any, Iterator, Optional, TypeVar from .server import HTTPServer from .wsgi import Server -T = TypeVar('T', bound=HTTPServer) +T = TypeVar("T", bound=HTTPServer) EPHEMERAL_PORT: int NO_INTERFACE: Optional[str] diff --git a/cheroot/workers/threadpool.py b/cheroot/workers/threadpool.py index 821e4b0ba0..265c50a6a2 100644 --- a/cheroot/workers/threadpool.py +++ b/cheroot/workers/threadpool.py @@ -16,7 +16,7 @@ from jaraco.functools import pass_none -__all__ = ('WorkerThread', 'ThreadPool') +__all__ = ("WorkerThread", "ThreadPool") class TrueyZero: @@ -70,32 +70,18 @@ def __init__(self, server): self.start_time = None self.work_time = 0 self.stats = { - 'Requests': lambda s: self.requests_seen + ( - self.start_time is None - and trueyzero - or self.conn.requests_seen - ), - 'Bytes Read': lambda s: self.bytes_read + ( - self.start_time is None - and trueyzero - or self.conn.rfile.bytes_read - ), - 'Bytes Written': lambda s: self.bytes_written + ( - self.start_time is None - and trueyzero - or self.conn.wfile.bytes_written - ), - 'Work Time': lambda s: self.work_time + ( - self.start_time is None - and trueyzero - or time.time() - self.start_time - ), - 'Read Throughput': lambda s: s['Bytes Read'](s) / ( - s['Work Time'](s) or 1e-6 - ), - 'Write Throughput': lambda s: s['Bytes Written'](s) / ( - s['Work Time'](s) or 1e-6 - ), + "Requests": lambda s: self.requests_seen + + (self.start_time is None and trueyzero or self.conn.requests_seen), + "Bytes Read": lambda s: self.bytes_read + + (self.start_time is None and trueyzero or self.conn.rfile.bytes_read), + "Bytes Written": lambda s: self.bytes_written + + (self.start_time is None and trueyzero or self.conn.wfile.bytes_written), + "Work Time": lambda s: self.work_time + + (self.start_time is None and trueyzero or time.time() - self.start_time), + "Read Throughput": lambda s: s["Bytes Read"](s) + / (s["Work Time"](s) or 1e-6), + "Write Throughput": lambda s: s["Bytes Written"](s) + / (s["Work Time"](s) or 1e-6), } threading.Thread.__init__(self) @@ -113,14 +99,14 @@ def run(self): # noqa: DAR401 KeyboardInterrupt SystemExit """ - self.server.stats['Worker Threads'][self.name] = self.stats + self.server.stats["Worker Threads"][self.name] = self.stats self.ready = True try: self._process_connections_until_interrupted() except (KeyboardInterrupt, SystemExit) as interrupt_exc: interrupt_cause = interrupt_exc.__cause__ or interrupt_exc self.server.error_log( - f'Setting the server interrupt flag to {interrupt_cause !r}', + f"Setting the server interrupt flag to {interrupt_cause !r}", level=logging.DEBUG, ) self.server.interrupt = interrupt_cause @@ -129,12 +115,12 @@ def run(self): # NOTE: of the worker. It is only reachable when exceptions happen # NOTE: in the `finally` branch of the internal try/except block. self.server.error_log( - 'A fatal exception happened. Setting the server interrupt flag' - f' to {underlying_exc !r} and giving up.' - '\N{NEW LINE}\N{NEW LINE}' - 'Please, report this on the Cheroot tracker at ' - ', ' - 'providing a full reproducer with as much context and details as possible.', + "A fatal exception happened. Setting the server interrupt flag" + f" to {underlying_exc !r} and giving up." + "\N{NEW LINE}\N{NEW LINE}" + "Please, report this on the Cheroot tracker at " + ", " + "providing a full reproducer with as much context and details as possible.", level=logging.CRITICAL, traceback=True, ) @@ -158,7 +144,7 @@ def _process_connections_until_interrupted(self): return self.conn = conn - is_stats_enabled = self.server.stats['Enabled'] + is_stats_enabled = self.server.stats["Enabled"] if is_stats_enabled: self.start_time = time.time() keep_conn_open = False @@ -167,9 +153,9 @@ def _process_connections_until_interrupted(self): except ConnectionError as connection_error: keep_conn_open = False # Drop the connection cleanly self.server.error_log( - 'Got a connection error while handling a ' - f'connection from {conn.remote_addr !s}:' - f'{conn.remote_port !s} ({connection_error !s})', + "Got a connection error while handling a " + f"connection from {conn.remote_addr !s}:" + f"{conn.remote_port !s} ({connection_error !s})", level=logging.INFO, ) continue @@ -177,9 +163,9 @@ def _process_connections_until_interrupted(self): # Shutdown request keep_conn_open = False # Drop the connection cleanly self.server.error_log( - 'Got a server shutdown request while handling a ' - f'connection from {conn.remote_addr !s}:' - f'{conn.remote_port !s} ({shutdown_request !s})', + "Got a server shutdown request while handling a " + f"connection from {conn.remote_addr !s}:" + f"{conn.remote_port !s} ({shutdown_request !s})", level=logging.DEBUG, ) raise SystemExit( @@ -193,8 +179,8 @@ def _process_connections_until_interrupted(self): # NOTE: the calling code would fail to schedule processing # NOTE: of new requests. self.server.error_log( - 'Unhandled error while processing an incoming ' - f'connection {unhandled_error !r}', + "Unhandled error while processing an incoming " + f"connection {unhandled_error !r}", level=logging.ERROR, traceback=True, ) @@ -231,8 +217,12 @@ class ThreadPool: """ def __init__( - self, server, min=10, max=-1, accepted_queue_size=-1, - accepted_queue_timeout=10, + self, + server, + min=10, + max=-1, + accepted_queue_size=-1, + accepted_queue_timeout=10, ): """Initialize HTTP requests queue instance. @@ -250,21 +240,21 @@ def __init__( :raises TypeError: if the max is not an integer or inf """ if min < 1: - raise ValueError(f'min={min!s} must be > 0') + raise ValueError(f"min={min!s} must be > 0") - if max == float('inf'): + if max == float("inf"): pass elif not isinstance(max, int) or max == 0: raise TypeError( - 'Expected an integer or the infinity value for the `max` ' - f'argument but got {max!r}.', + "Expected an integer or the infinity value for the `max` " + f"argument but got {max!r}.", ) elif max < 0: - max = float('inf') + max = float("inf") if max < min: raise ValueError( - f'max={max!s} must be > min={min!s} (or infinity for no max)', + f"max={max!s} must be > min={min!s} (or infinity for no max)", ) self.server = server @@ -282,7 +272,7 @@ def start(self): :raises RuntimeError: if the pool is already started """ if self._threads: - raise RuntimeError('Threadpools can only be started once.') + raise RuntimeError("Threadpools can only be started once.") self.grow(self.min) @property @@ -317,15 +307,12 @@ def grow(self, amount): workers = [self._spawn_worker() for i in range(n_new)] for worker in workers: while not worker.ready: - time.sleep(.1) + time.sleep(0.1) self._threads.extend(workers) def _spawn_worker(self): worker = WorkerThread(self.server) - worker.name = ( - 'CP Server {worker_name!s}'. - format(worker_name=worker.name) - ) + worker.name = "CP Server {worker_name!s}".format(worker_name=worker.name) worker.start() return worker @@ -362,8 +349,8 @@ def stop(self, timeout=5): if timeout is not None and timeout < 0: timeout = None warnings.warning( - 'In the future, negative timeouts to Server.stop() ' - 'will be equivalent to a timeout of zero.', + "In the future, negative timeouts to Server.stop() " + "will be equivalent to a timeout of zero.", stacklevel=2, ) @@ -415,9 +402,7 @@ def _clear_threads(self): # threads = pop_all(self._threads) threads, self._threads[:] = self._threads[:], [] return ( - thread - for thread in threads - if thread is not threading.current_thread() + thread for thread in threads if thread is not threading.current_thread() ) @property diff --git a/cheroot/workers/threadpool.pyi b/cheroot/workers/threadpool.pyi index 201d39140b..02a09b6c90 100644 --- a/cheroot/workers/threadpool.pyi +++ b/cheroot/workers/threadpool.pyi @@ -25,7 +25,14 @@ class ThreadPool: min: Any max: Any get: Any - def __init__(self, server, min: int = ..., max: int = ..., accepted_queue_size: int = ..., accepted_queue_timeout: int = ...) -> None: ... + def __init__( + self, + server, + min: int = ..., + max: int = ..., + accepted_queue_size: int = ..., + accepted_queue_timeout: int = ..., + ) -> None: ... def start(self) -> None: ... @property def idle(self): ... diff --git a/cheroot/wsgi.py b/cheroot/wsgi.py index 1dbe10ee2c..cfed538474 100644 --- a/cheroot/wsgi.py +++ b/cheroot/wsgi.py @@ -39,10 +39,19 @@ class Server(server.HTTPServer): """The version of WSGI to produce.""" def __init__( - self, bind_addr, wsgi_app, numthreads=10, server_name=None, - max=-1, request_queue_size=5, timeout=10, shutdown_timeout=5, - accepted_queue_size=-1, accepted_queue_timeout=10, - peercreds_enabled=False, peercreds_resolve_enabled=False, + self, + bind_addr, + wsgi_app, + numthreads=10, + server_name=None, + max=-1, + request_queue_size=5, + timeout=10, + shutdown_timeout=5, + accepted_queue_size=-1, + accepted_queue_timeout=10, + peercreds_enabled=False, + peercreds_resolve_enabled=False, reuse_port=False, ): """Initialize WSGI Server instance. @@ -77,7 +86,9 @@ def __init__( self.timeout = timeout self.shutdown_timeout = shutdown_timeout self.requests = threadpool.ThreadPool( - self, min=numthreads or 1, max=max, + self, + min=numthreads or 1, + max=max, accepted_queue_size=accepted_queue_size, accepted_queue_timeout=accepted_queue_timeout, ) @@ -137,12 +148,12 @@ def respond(self): try: for chunk in filter(None, response): if not isinstance(chunk, bytes): - raise ValueError('WSGI Applications must yield bytes') + raise ValueError("WSGI Applications must yield bytes") self.write(chunk) finally: # Send headers if not already sent self.req.ensure_headers_sent() - if hasattr(response, 'close'): + if hasattr(response, "close"): response.close() def start_response(self, status, headers, exc_info=None): # noqa: WPS238 @@ -151,8 +162,7 @@ def start_response(self, status, headers, exc_info=None): # noqa: WPS238 # if and only if the exc_info argument is provided." if self.started_response and not exc_info: raise RuntimeError( - 'WSGI start_response called a second ' - 'time with no exc_info.', + "WSGI start_response called a second " "time with no exc_info.", ) self.started_response = True @@ -168,13 +178,13 @@ def start_response(self, status, headers, exc_info=None): # noqa: WPS238 for k, v in headers: if not isinstance(k, str): raise TypeError( - 'WSGI response header key %r is not of type str.' % k, + "WSGI response header key %r is not of type str." % k, ) if not isinstance(v, str): raise TypeError( - 'WSGI response header value %r is not of type str.' % v, + "WSGI response header value %r is not of type str." % v, ) - if k.lower() == 'content-length': + if k.lower() == "content-length": self.remaining_bytes_out = int(v) out_header = ntob(k), ntob(v) self.req.outheaders.append(out_header) @@ -191,8 +201,8 @@ def _encode_status(status): "Latin-1" set. """ if not isinstance(status, str): - raise TypeError('WSGI response status is not of type str.') - return status.encode('ISO-8859-1') + raise TypeError("WSGI response status is not of type str.") + return status.encode("ISO-8859-1") def write(self, chunk): """WSGI callable to write unbuffered data to the client. @@ -201,7 +211,7 @@ def write(self, chunk): data from the iterable returned by the WSGI application). """ if not self.started_response: - raise RuntimeError('WSGI write called before start_response.') + raise RuntimeError("WSGI write called before start_response.") chunklen = len(chunk) rbo = self.remaining_bytes_out @@ -209,9 +219,9 @@ def write(self, chunk): if not self.req.sent_headers: # Whew. We can send a 500 to the client. self.req.simple_response( - '500 Internal Server Error', - 'The requested resource returned more bytes than the ' - 'declared Content-Length.', + "500 Internal Server Error", + "The requested resource returned more bytes than the " + "declared Content-Length.", ) else: # Dang. We have probably already sent data. Truncate the chunk @@ -226,7 +236,7 @@ def write(self, chunk): rbo -= chunklen if rbo < 0: raise ValueError( - 'Response body exceeds the declared Content-Length.', + "Response body exceeds the declared Content-Length.", ) @@ -243,41 +253,41 @@ def get_environ(self): # set a non-standard environ entry so the WSGI app can know what # the *real* server protocol is (and what features to support). # See http://www.faqs.org/rfcs/rfc2145.html. - 'ACTUAL_SERVER_PROTOCOL': req.server.protocol, - 'PATH_INFO': bton(req.path), - 'QUERY_STRING': bton(req.qs), - 'REMOTE_ADDR': req_conn.remote_addr or '', - 'REMOTE_PORT': str(req_conn.remote_port or ''), - 'REQUEST_METHOD': bton(req.method), - 'REQUEST_URI': bton(req.uri), - 'SCRIPT_NAME': '', - 'SERVER_NAME': req.server.server_name, + "ACTUAL_SERVER_PROTOCOL": req.server.protocol, + "PATH_INFO": bton(req.path), + "QUERY_STRING": bton(req.qs), + "REMOTE_ADDR": req_conn.remote_addr or "", + "REMOTE_PORT": str(req_conn.remote_port or ""), + "REQUEST_METHOD": bton(req.method), + "REQUEST_URI": bton(req.uri), + "SCRIPT_NAME": "", + "SERVER_NAME": req.server.server_name, # Bah. "SERVER_PROTOCOL" is actually the REQUEST protocol. - 'SERVER_PROTOCOL': bton(req.request_protocol), - 'SERVER_SOFTWARE': req.server.software, - 'wsgi.errors': sys.stderr, - 'wsgi.input': req.rfile, - 'wsgi.input_terminated': bool(req.chunked_read), - 'wsgi.multiprocess': False, - 'wsgi.multithread': True, - 'wsgi.run_once': False, - 'wsgi.url_scheme': bton(req.scheme), - 'wsgi.version': self.version, + "SERVER_PROTOCOL": bton(req.request_protocol), + "SERVER_SOFTWARE": req.server.software, + "wsgi.errors": sys.stderr, + "wsgi.input": req.rfile, + "wsgi.input_terminated": bool(req.chunked_read), + "wsgi.multiprocess": False, + "wsgi.multithread": True, + "wsgi.run_once": False, + "wsgi.url_scheme": bton(req.scheme), + "wsgi.version": self.version, } if isinstance(req.server.bind_addr, str): # AF_UNIX. This isn't really allowed by WSGI, which doesn't # address unix domain sockets. But it's better than nothing. - env['SERVER_PORT'] = '' + env["SERVER_PORT"] = "" try: - env['X_REMOTE_PID'] = str(req_conn.peer_pid) - env['X_REMOTE_UID'] = str(req_conn.peer_uid) - env['X_REMOTE_GID'] = str(req_conn.peer_gid) + env["X_REMOTE_PID"] = str(req_conn.peer_pid) + env["X_REMOTE_UID"] = str(req_conn.peer_uid) + env["X_REMOTE_GID"] = str(req_conn.peer_gid) - env['X_REMOTE_USER'] = str(req_conn.peer_user) - env['X_REMOTE_GROUP'] = str(req_conn.peer_group) + env["X_REMOTE_USER"] = str(req_conn.peer_user) + env["X_REMOTE_GROUP"] = str(req_conn.peer_group) - env['REMOTE_USER'] = env['X_REMOTE_USER'] + env["REMOTE_USER"] = env["X_REMOTE_USER"] except RuntimeError: """Unable to retrieve peer creds data. @@ -285,25 +295,26 @@ def get_environ(self): unsupported socket type, or disabled. """ else: - env['SERVER_PORT'] = str(req.server.bind_addr[1]) + env["SERVER_PORT"] = str(req.server.bind_addr[1]) # Request headers env.update( ( - 'HTTP_{header_name!s}'. - format(header_name=bton(k).upper().replace('-', '_')), + "HTTP_{header_name!s}".format( + header_name=bton(k).upper().replace("-", "_") + ), bton(v), ) for k, v in req.inheaders.items() ) # CONTENT_TYPE/CONTENT_LENGTH - ct = env.pop('HTTP_CONTENT_TYPE', None) + ct = env.pop("HTTP_CONTENT_TYPE", None) if ct is not None: - env['CONTENT_TYPE'] = ct - cl = env.pop('HTTP_CONTENT_LENGTH', None) + env["CONTENT_TYPE"] = ct + cl = env.pop("HTTP_CONTENT_LENGTH", None) if cl is not None: - env['CONTENT_LENGTH'] = cl + env["CONTENT_LENGTH"] = cl if req.conn.ssl_env: env.update(req.conn.ssl_env) @@ -318,7 +329,7 @@ class Gateway_u0(Gateway_10): and values in both Python 2 and Python 3. """ - version = 'u', 0 + version = "u", 0 def get_environ(self): """Return a new environ dict targeting the given wsgi.version.""" @@ -327,15 +338,15 @@ def get_environ(self): env = dict(env_10.items()) # Request-URI - enc = env.setdefault('wsgi.url_encoding', 'utf-8') + enc = env.setdefault("wsgi.url_encoding", "utf-8") try: - env['PATH_INFO'] = req.path.decode(enc) - env['QUERY_STRING'] = req.qs.decode(enc) + env["PATH_INFO"] = req.path.decode(enc) + env["QUERY_STRING"] = req.qs.decode(enc) except UnicodeDecodeError: # Fall back to latin 1 so apps can transcode if needed. - env['wsgi.url_encoding'] = 'ISO-8859-1' - env['PATH_INFO'] = env_10['PATH_INFO'] - env['QUERY_STRING'] = env_10['QUERY_STRING'] + env["wsgi.url_encoding"] = "ISO-8859-1" + env["PATH_INFO"] = env_10["PATH_INFO"] + env["QUERY_STRING"] = env_10["QUERY_STRING"] env.update(env.items()) @@ -363,11 +374,12 @@ def __init__(self, apps): # Sort the apps by len(path), descending def by_path_len(app): return len(app[0]) + apps.sort(key=by_path_len, reverse=True) # The path_prefix strings must start, but not end, with a slash. # Use "" instead of "/". - self.apps = [(p.rstrip('/'), a) for p, a in apps] + self.apps = [(p.rstrip("/"), a) for p, a in apps] def __call__(self, environ, start_response): """Process incoming WSGI request. @@ -384,22 +396,23 @@ def __call__(self, environ, start_response): HTTP response body """ - path = environ['PATH_INFO'] or '/' + path = environ["PATH_INFO"] or "/" for p, app in self.apps: # The apps list should be sorted by length, descending. - if path.startswith('{path!s}/'.format(path=p)) or path == p: + if path.startswith("{path!s}/".format(path=p)) or path == p: environ = environ.copy() - environ['SCRIPT_NAME'] = environ.get('SCRIPT_NAME', '') + p - environ['PATH_INFO'] = path[len(p):] + environ["SCRIPT_NAME"] = environ.get("SCRIPT_NAME", "") + p + environ["PATH_INFO"] = path[len(p) :] return app(environ, start_response) start_response( - '404 Not Found', [ - ('Content-Type', 'text/plain'), - ('Content-Length', '0'), + "404 Not Found", + [ + ("Content-Type", "text/plain"), + ("Content-Length", "0"), ], ) - return [''] + return [""] # compatibility aliases diff --git a/cheroot/wsgi.pyi b/cheroot/wsgi.pyi index f96a18f928..99d88a534a 100644 --- a/cheroot/wsgi.pyi +++ b/cheroot/wsgi.pyi @@ -8,7 +8,22 @@ class Server(server.HTTPServer): timeout: Any shutdown_timeout: Any requests: Any - def __init__(self, bind_addr, wsgi_app, numthreads: int = ..., server_name: Any | None = ..., max: int = ..., request_queue_size: int = ..., timeout: int = ..., shutdown_timeout: int = ..., accepted_queue_size: int = ..., accepted_queue_timeout: int = ..., peercreds_enabled: bool = ..., peercreds_resolve_enabled: bool = ..., reuse_port: bool = ...) -> None: ... + def __init__( + self, + bind_addr, + wsgi_app, + numthreads: int = ..., + server_name: Any | None = ..., + max: int = ..., + request_queue_size: int = ..., + timeout: int = ..., + shutdown_timeout: int = ..., + accepted_queue_size: int = ..., + accepted_queue_timeout: int = ..., + peercreds_enabled: bool = ..., + peercreds_resolve_enabled: bool = ..., + reuse_port: bool = ..., + ) -> None: ... @property def numthreads(self): ... @numthreads.setter @@ -41,7 +56,6 @@ class PathInfoDispatcher: def __init__(self, apps): ... def __call__(self, environ, start_response): ... - WSGIServer = Server WSGIGateway = Gateway WSGIGateway_u0 = Gateway_u0 diff --git a/docs/conf.py b/docs/conf.py index 72d4f4dac0..9d05bf0e23 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,13 +11,13 @@ PROJECT_ROOT_DIR = Path(__file__).parents[1].resolve() IS_RELEASE_ON_RTD = ( - os.getenv('READTHEDOCS', 'False') == 'True' - and os.environ['READTHEDOCS_VERSION_TYPE'] == 'tag' + os.getenv("READTHEDOCS", "False") == "True" + and os.environ["READTHEDOCS_VERSION_TYPE"] == "tag" ) if IS_RELEASE_ON_RTD: tags: set[str] # pylint: disable-next=used-before-assignment - tags.add('is_release') # noqa: F821 + tags.add("is_release") # noqa: F821 # Make in-tree extension importable in non-tox setups/envs, like RTD. @@ -28,41 +28,40 @@ extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.autosectionlabel', # autocreate section targets for refs - 'sphinx.ext.doctest', - 'sphinx.ext.extlinks', - 'sphinx.ext.intersphinx', + "sphinx.ext.autodoc", + "sphinx.ext.autosectionlabel", # autocreate section targets for refs + "sphinx.ext.doctest", + "sphinx.ext.extlinks", + "sphinx.ext.intersphinx", # Third-party extensions: - 'jaraco.packaging.sphinx', - 'sphinx_tabs.tabs', - 'sphinxcontrib.apidoc', - 'sphinxcontrib.towncrier.ext', # provides `.. towncrier-draft-entries::` - + "jaraco.packaging.sphinx", + "sphinx_tabs.tabs", + "sphinxcontrib.apidoc", + "sphinxcontrib.towncrier.ext", # provides `.. towncrier-draft-entries::` # In-tree extensions: - 'spelling_stub_ext', # auto-loads `sphinxcontrib.spelling` if installed + "spelling_stub_ext", # auto-loads `sphinxcontrib.spelling` if installed ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = [ - 'changelog-fragments.d/**', # Towncrier-managed change notes + "changelog-fragments.d/**", # Towncrier-managed change notes ] -master_doc = 'index' +master_doc = "index" apidoc_excluded_paths = [] apidoc_extra_args = [ - '--implicit-namespaces', - '--private', # include “_private” modules + "--implicit-namespaces", + "--private", # include “_private” modules ] -apidoc_module_dir = '../cheroot' +apidoc_module_dir = "../cheroot" apidoc_module_first = False -apidoc_output_dir = 'pkg' +apidoc_output_dir = "pkg" apidoc_separate_modules = True apidoc_toc_file = None @@ -73,62 +72,58 @@ spelling_ignore_wiki_words = True spelling_show_suggestions = True spelling_word_list_filename = [ - 'spelling_wordlist.txt', + "spelling_wordlist.txt", ] -github_url = 'https://github.com' -github_repo_org = 'cherrypy' -github_repo_name = 'cheroot' -github_repo_slug = f'{github_repo_org}/{github_repo_name}' -github_repo_url = f'{github_url}/{github_repo_slug}' -cp_github_repo_url = f'{github_url}/{github_repo_org}/cherrypy' -github_sponsors_url = f'{github_url}/sponsors' +github_url = "https://github.com" +github_repo_org = "cherrypy" +github_repo_name = "cheroot" +github_repo_slug = f"{github_repo_org}/{github_repo_name}" +github_repo_url = f"{github_url}/{github_repo_slug}" +cp_github_repo_url = f"{github_url}/{github_repo_org}/cherrypy" +github_sponsors_url = f"{github_url}/sponsors" extlinks = { - 'issue': (f'{github_repo_url}/issues/%s', '#%s'), - 'pr': (f'{github_repo_url}/pull/%s', 'PR #%s'), - 'commit': (f'{github_repo_url}/commit/%s', '%s'), - 'cp-issue': (f'{cp_github_repo_url}/issues/%s', 'CherryPy #%s'), - 'cp-pr': (f'{cp_github_repo_url}/pull/%s', 'CherryPy PR #%s'), - 'gh': (f'{github_url}/%s', 'GitHub: %s'), - 'user': (f'{github_sponsors_url}/%s', '@%s'), + "issue": (f"{github_repo_url}/issues/%s", "#%s"), + "pr": (f"{github_repo_url}/pull/%s", "PR #%s"), + "commit": (f"{github_repo_url}/commit/%s", "%s"), + "cp-issue": (f"{cp_github_repo_url}/issues/%s", "CherryPy #%s"), + "cp-pr": (f"{cp_github_repo_url}/pull/%s", "CherryPy PR #%s"), + "gh": (f"{github_url}/%s", "GitHub: %s"), + "user": (f"{github_sponsors_url}/%s", "@%s"), } intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'python2': ('https://docs.python.org/2', None), + "python": ("https://docs.python.org/3", None), + "python2": ("https://docs.python.org/2", None), # Ref: https://github.com/cherrypy/cherrypy/issues/1872 - 'cherrypy': ( - 'https://docs.cherrypy.dev/en/latest', - ('https://cherrypy.rtfd.io/en/latest', None), + "cherrypy": ( + "https://docs.cherrypy.dev/en/latest", + ("https://cherrypy.rtfd.io/en/latest", None), ), - 'trustme': ('https://trustme.readthedocs.io/en/latest/', None), - 'ddt': ('https://ddt.readthedocs.io/en/latest/', None), - 'pyopenssl': ('https://www.pyopenssl.org/en/latest/', None), - 'towncrier': ('https://towncrier.rtfd.io/en/latest', None), + "trustme": ("https://trustme.readthedocs.io/en/latest/", None), + "ddt": ("https://ddt.readthedocs.io/en/latest/", None), + "pyopenssl": ("https://www.pyopenssl.org/en/latest/", None), + "towncrier": ("https://towncrier.rtfd.io/en/latest", None), } linkcheck_ignore = [ - r'http://localhost:\d+/', # local URLs - r'https://codecov\.io/gh/cherrypy/cheroot/branch/master/graph/badge\.svg', - r'https://github\.com/cherrypy/cheroot/actions', # 404 if no auth - + r"http://localhost:\d+/", # local URLs + r"https://codecov\.io/gh/cherrypy/cheroot/branch/master/graph/badge\.svg", + r"https://github\.com/cherrypy/cheroot/actions", # 404 if no auth # Too many links to GitHub so they cause # "429 Client Error: too many requests for url" # Ref: https://github.com/sphinx-doc/sphinx/issues/7388 - r'https://github\.com/cherrypy/cheroot/issues', - r'https://github\.com/cherrypy/cheroot/pull', - r'https://github\.com/cherrypy/cherrypy/issues', - r'https://github\.com/cherrypy/cherrypy/pull', - + r"https://github\.com/cherrypy/cheroot/issues", + r"https://github\.com/cherrypy/cheroot/pull", + r"https://github\.com/cherrypy/cherrypy/issues", + r"https://github\.com/cherrypy/cherrypy/pull", # Has an ephemeral anchor (line-range) but actual HTML has separate per- # line anchors. - r'https://github\.com' - r'/python/cpython/blob/c39b52f/Lib/poplib\.py#L297-L302', - r'https://github\.com' - r'/python/cpython/blob/c39b52f/Lib/poplib\.py#user-content-L297-L302', - - r'^https://matrix\.to/#', # these render fully on front-end from anchors + r"https://github\.com" r"/python/cpython/blob/c39b52f/Lib/poplib\.py#L297-L302", + r"https://github\.com" + r"/python/cpython/blob/c39b52f/Lib/poplib\.py#user-content-L297-L302", + r"^https://matrix\.to/#", # these render fully on front-end from anchors ] linkcheck_workers = 25 @@ -143,33 +138,32 @@ # NOTE: consider having a separate ignore file # Ref: https://stackoverflow.com/a/30624034/595220 nitpick_ignore = [ - ('py:const', 'socket.SO_PEERCRED'), - ('py:class', '_pyio.BufferedWriter'), - ('py:class', '_pyio.BufferedReader'), - ('py:class', 'unittest.case.TestCase'), - ('py:meth', 'cheroot.connections.ConnectionManager.get_conn'), - + ("py:const", "socket.SO_PEERCRED"), + ("py:class", "_pyio.BufferedWriter"), + ("py:class", "_pyio.BufferedReader"), + ("py:class", "unittest.case.TestCase"), + ("py:meth", "cheroot.connections.ConnectionManager.get_conn"), # Ref: https://github.com/pyca/pyopenssl/issues/1012 - ('py:class', 'pyopenssl:OpenSSL.SSL.Context'), + ("py:class", "pyopenssl:OpenSSL.SSL.Context"), ] # -- Options for towncrier_draft extension ----------------------------------- # or: 'sphinx-version', 'sphinx-release' -towncrier_draft_autoversion_mode = 'draft' +towncrier_draft_autoversion_mode = "draft" towncrier_draft_include_empty = True towncrier_draft_working_directory = PROJECT_ROOT_DIR -towncrier_draft_config_path = 'towncrier.toml' # relative to cwd +towncrier_draft_config_path = "towncrier.toml" # relative to cwd # Ref: # * https://github.com/djungelorm/sphinx-tabs/issues/26#issuecomment-422160463 -sphinx_tabs_valid_builders = ['linkcheck'] # prevent linkcheck warning +sphinx_tabs_valid_builders = ["linkcheck"] # prevent linkcheck warning # Ref: https://github.com/python-attrs/attrs/pull/571/files\ # #diff-85987f48f1258d9ee486e3191495582dR82 -default_role = 'any' +default_role = "any" -html_theme = 'furo' +html_theme = "furo" diff --git a/docs/spelling_stub_ext.py b/docs/spelling_stub_ext.py index 0eda52b3f6..45f936d2be 100644 --- a/docs/spelling_stub_ext.py +++ b/docs/spelling_stub_ext.py @@ -38,18 +38,18 @@ def _skip(self, word: str) -> bool: return False logger.debug( - 'Known version words: %r', # noqa: WPS323 + "Known version words: %r", # noqa: WPS323 known_version_words, ) logger.debug( - 'Ignoring %r because it is a known version', # noqa: WPS323 + "Ignoring %r because it is a known version", # noqa: WPS323 word, ) return True app.config.spelling_filters = [VersionFilter] - app.setup_extension('sphinxcontrib.spelling') + app.setup_extension("sphinxcontrib.spelling") class SpellingNoOpDirective(SphinxDirective): @@ -65,11 +65,11 @@ def run(self) -> List[nodes.Node]: def setup(app: Sphinx) -> None: """Initialize the extension.""" if _EnchantTokenizeFilterBase is object: - app.add_directive('spelling', SpellingNoOpDirective) + app.add_directive("spelling", SpellingNoOpDirective) else: - app.connect('config-inited', _configure_spelling_ext) + app.connect("config-inited", _configure_spelling_ext) return { - 'parallel_read_safe': True, - 'version': __version__, + "parallel_read_safe": True, + "version": __version__, }