diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4275059d..1d6d3ed1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,21 +8,20 @@ on: jobs: tox: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 strategy: - max-parallel: 5 matrix: python-version: - - "3.9" - "3.10" - "3.11" - "3.12" - "3.13" - - "pypy3.9" + - "3.14" + - "pypy3.11" steps: - - uses: actions/checkout@v2 - - uses: actions/setup-python@v5 + - uses: actions/checkout@v5 + - uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install tox diff --git a/.prospector.yml b/.prospector.yml deleted file mode 100644 index 6f06a69f..00000000 --- a/.prospector.yml +++ /dev/null @@ -1,10 +0,0 @@ -pylint: - disable: - - cyclic-import - - no-else-break - - no-else-continue - - unused-argument - - useless-object-inheritance - - options: - max-args: 6 diff --git a/.readthedocs.yaml b/.readthedocs.yaml index b4abc6dc..322b844b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -1,34 +1,16 @@ -# .readthedocs.yaml -# Read the Docs configuration file +# Read the Docs configuration file for Sphinx projects # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details -# Required version: 2 -# Set the OS, Python version and other tools you might need build: - os: ubuntu-22.04 + os: ubuntu-24.04 tools: - python: "3.11" - # You can also specify other tool versions: - # nodejs: "19" - # rust: "1.64" - # golang: "1.19" + python: "3.14" + jobs: + install: + - pip install -U pip + - pip install --group 'docs' . -# Build documentation in the "docs/" directory with Sphinx sphinx: configuration: docs/source/conf.py - -# Optionally build your docs in additional formats such as PDF and ePub -# formats: -# - pdf -# - epub - -# Optional but recommended, declare the Python requirements required -# to build your documentation -# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html -python: - install: - - requirements: docs/source/requirements.txt - - method: pip - path: . diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b840eb63..6ecaf7d4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,11 +1,17 @@ Release History =============== -Unreleased ----------- - -- +1.3.0 (2025-06-XX) +------------------ +- Require h11>=0.16 dependency. +- Fix "Upgrade" header value to match RFC. +- Add reason "Switching Protocols" to handshake response. +- Add docs for `wsproto.Connection` +- Add support for Python 3.12, 3.13, and 3.14. +- Drop support for Python 3.7, 3.8, and 3.9. +- Improve Python typing, specifically bytes vs. bytearray. +- Various linting, styling, and packaging improvements. 1.2.0 (2022-08-23) ------------------ diff --git a/MANIFEST.in b/MANIFEST.in index e2ea2429..7ac1eaa1 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,8 +1,12 @@ graft src/wsproto graft example graft docs -graft test +graft tests graft bench + prune docs/build -include README.rst LICENSE CHANGELOG.rst tox.ini .readthedocs.yaml + +include README.rst LICENSE CHANGELOG.rst pyproject.toml + global-exclude *.pyc *.pyo *.swo *.swp *.map *.yml *.DS_Store .coverage +exclude .readthedocs.yaml diff --git a/bench/connection.py b/bench/connection.py index 6585944a..3979ab5f 100644 --- a/bench/connection.py +++ b/bench/connection.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import random import time from typing import List diff --git a/docs/source/conf.py b/docs/source/conf.py index ed76baf9..dd81b7c3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,16 +10,18 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # +from __future__ import annotations + import os -import sys import re +import sys -sys.path.insert(0, os.path.abspath('../..')) +sys.path.insert(0, os.path.abspath("../..")) PROJECT_ROOT = os.path.dirname(__file__) # Get the version version_regex = r'__version__ = ["\']([^"\']*)["\']' -with open(os.path.join(PROJECT_ROOT, '../../', 'src/wsproto/__init__.py')) as file_: +with open(os.path.join(PROJECT_ROOT, "../../", "src/wsproto/__init__.py")) as file_: text = file_.read() match = re.search(version_regex, text) version = match.group(1) @@ -27,9 +29,9 @@ # -- Project information ----------------------------------------------------- -project = 'wsproto' -copyright = '2020, Benno Rice' -author = 'Benno Rice' +project = "wsproto" +copyright = "2020, Benno Rice" +author = "Benno Rice" release = version # -- General configuration --------------------------------------------------- @@ -38,13 +40,13 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.intersphinx', - 'sphinx.ext.viewcode', + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "sphinx.ext.viewcode", ] # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -53,10 +55,10 @@ # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = { - 'python': ('https://docs.python.org/', None), + "python": ("https://docs.python.org/", None), } -master_doc = 'index' +master_doc = "index" # -- Options for HTML output ------------------------------------------------- @@ -64,9 +66,9 @@ # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'sphinx_rtd_theme' +html_theme = "default" # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] diff --git a/docs/source/requirements.txt b/docs/source/requirements.txt deleted file mode 100644 index cbf1e365..00000000 --- a/docs/source/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -sphinx -sphinx-rtd-theme diff --git a/example/synchronous_client.py b/example/synchronous_client.py index cb507035..5a7e812b 100644 --- a/example/synchronous_client.py +++ b/example/synchronous_client.py @@ -3,6 +3,7 @@ response. This is a poor implementation of a client. It is only intended to demonstrate how to use wsproto. """ +from __future__ import annotations import socket import sys @@ -28,7 +29,7 @@ def main() -> None: host = sys.argv[1] port = int(sys.argv[2]) except (IndexError, ValueError): - print("Usage: {} ".format(sys.argv[0])) + print(f"Usage: {sys.argv[0]} ") sys.exit(1) try: @@ -47,7 +48,6 @@ def wsproto_demo(host: str, port: int) -> None: 3) Send ping and display pong 4) Negotiate WebSocket closing handshake """ - # 0) Open TCP connection print(f"Connecting to {host}:{port}") conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -89,7 +89,7 @@ def wsproto_demo(host: str, port: int) -> None: def net_send(out_data: bytes, conn: socket.socket) -> None: """Write pending data from websocket to network.""" - print("Sending {} bytes".format(len(out_data))) + print(f"Sending {len(out_data)} bytes") conn.send(out_data) @@ -102,7 +102,7 @@ def net_recv(ws: WSConnection, conn: socket.socket) -> None: print("Received 0 bytes (connection closed)") ws.receive_data(None) else: - print("Received {} bytes".format(len(in_data))) + print(f"Received {len(in_data)} bytes") ws.receive_data(in_data) diff --git a/example/synchronous_server.py b/example/synchronous_server.py index 60d52408..f2d8a634 100644 --- a/example/synchronous_server.py +++ b/example/synchronous_server.py @@ -4,6 +4,7 @@ implementation of a server! It is only intended to demonstrate how to use wsproto. """ +from __future__ import annotations import socket import sys @@ -28,7 +29,7 @@ def main() -> None: ip = sys.argv[1] port = int(sys.argv[2]) except (IndexError, ValueError): - print("Usage: {} ".format(sys.argv[0])) + print(f"Usage: {sys.argv[0]} ") sys.exit(1) server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) @@ -40,7 +41,7 @@ def main() -> None: while True: print("Waiting for connection...") (stream, addr) = server.accept() - print("Client connected: {}:{}".format(addr[0], addr[1])) + print(f"Client connected: {addr[0]}:{addr[1]}") handle_connection(stream) stream.shutdown(socket.SHUT_WR) stream.close() @@ -67,7 +68,7 @@ def handle_connection(stream: socket.socket) -> None: while running: # 1) Read data from network in_data = stream.recv(RECEIVE_BYTES) - print("Received {} bytes".format(len(in_data))) + print(f"Received {len(in_data)} bytes") ws.receive_data(in_data) # 2) Get new events and handle them @@ -80,9 +81,7 @@ def handle_connection(stream: socket.socket) -> None: elif isinstance(event, CloseConnection): # Print log message and break out print( - "Connection closed: code={} reason={}".format( - event.code, event.reason - ) + f"Connection closed: code={event.code} reason={event.reason}", ) out_data += ws.send(event.response()) running = False @@ -100,7 +99,7 @@ def handle_connection(stream: socket.socket) -> None: print(f"Unknown event: {event!r}") # 4) Send data from wsproto to network - print("Sending {} bytes".format(len(out_data))) + print(f"Sending {len(out_data)} bytes") stream.send(out_data) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..7e94de01 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,232 @@ +# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ +# https://packaging.python.org/en/latest/specifications/pyproject-toml/ + +[build-system] +requires = ["setuptools>=77"] +build-backend = "setuptools.build_meta" + +[project] +name = "wsproto" +description = "Pure-Python WebSocket protocol implementation" +readme = { file = "README.rst", content-type = "text/x-rst" } +license-files = [ "LICENSE" ] + +authors = [ + { name = "Benno Rice", email = "benno@jeamland.net" } +] +maintainers = [ + { name = "Thomas Kriechbaumer", email = "thomas@kriechbaumer.name" }, +] + +requires-python = ">=3.10" +dependencies = [ + "h11>=0.16.0,<1", +] +dynamic = ["version"] + +# For a list of valid classifiers, see https://pypi.org/classifiers/ +classifiers = [ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] + +[project.urls] +"Homepage" = "https://github.com/python-hyper/wsproto/" +"Bug Reports" = "https://github.com/python-hyper/wsproto/issues" +"Source" = "https://github.com/python-hyper/wsproto/" +"Documentation" = "https://python-hyper.org/" + +[dependency-groups] +dev = [ + { include-group = "testing" }, + { include-group = "linting" }, + { include-group = "packaging" }, + { include-group = "docs" }, +] + +testing = [ + "pytest>=8.3.3,<9", + "pytest-cov>=6.0.0,<7", + "pytest-xdist>=3.6.1,<4", + "hypothesis>=6.119.4,<7", +] + +linting = [ + "ruff>=0.14.4,<1", + "mypy>=1.18.2,<2", + "typing_extensions>=4.15.0", +] + +packaging = [ + "check-manifest==0.51", + "readme-renderer==44.0", + "build>=1.3.0,<2", + "twine>=6.2.0,<7", + "wheel>=0.45.1,<1", +] + +docs = [ + "sphinx>=8.2.3,<9", +] + +[tool.setuptools.packages.find] +where = [ "src" ] + +[tool.setuptools.package-data] +wsproto = [ "py.typed" ] + +[tool.setuptools.dynamic] +version = { attr = "wsproto.__version__" } + +[tool.ruff] +line-length = 150 +target-version = "py39" +format.preview = true +format.docstring-code-line-length = 100 +format.docstring-code-format = true +lint.select = [ + "ALL", +] +lint.ignore = [ + "PYI034", # PEP 673 not yet available in Python 3.9 - only in 3.11+ + "ANN001", # args with typing.Any + "ANN002", # args with typing.Any + "ANN003", # kwargs with typing.Any + "ANN401", # kwargs with typing.Any + "SLF001", # implementation detail + "CPY", # not required + "D101", # docs readability + "D102", # docs readability + "D103", # docs readability + "D105", # docs readability + "D107", # docs readability + "D200", # docs readability + "D205", # docs readability + "D205", # docs readability + "D203", # docs readability + "D212", # docs readability + "D400", # docs readability + "D401", # docs readability + "D415", # docs readability + "PLR2004", # readability + "SIM108", # readability + "RUF012", # readability + "FBT001", # readability + "FBT002", # readability + "PGH003", # readability + "PGH004", # readability + "FIX001", # readability + "FIX002", # readability + "TD001", # readability + "TD002", # readability + "TD003", # readability + "S101", # readability + "PD901", # readability + "ERA001", # readability + "ARG001", # readability + "ARG002", # readability + "PLR0913", # readability + "FBT003", # readability +] +lint.isort.required-imports = [ "from __future__ import annotations" ] + +[tool.mypy] +warn_unused_configs = true +show_error_codes = true +strict = true + +[[tool.mypy.overrides]] +module = [ + "h11", + "h11._headers", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = [ "tests" ] + +[tool.coverage.run] +branch = true +source = [ "wsproto" ] + +[tool.coverage.report] +show_missing = true +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError()", + 'assert False, "Should not be reachable"', + # .*:.* # Python \d.* + # .*:.* # Platform-specific: +] + +[tool.coverage.paths] +source = [ + "src/", + ".tox/**/site-packages/", +] + +[tool.tox] +min_version = "4.23.2" +env_list = [ "py310", "py311", "py312", "py313", "py314", "pypy311", "lint", "docs", "packaging" ] + +[tool.tox.gh-actions] +python = """ + 3.10: py310, lint, packaging + 3.11: py311 + 3.12: py312 + 3.13: py313 + 3.14: py314, docs + pypy3.11: pypy311 +""" + +[tool.tox.env_run_base] +dependency_groups = ["testing"] +commands = [ + ["python", "-bb", "-m", "pytest", "--cov-report=xml", "--cov-report=term", "--cov=wsproto", { replace = "posargs", extend = true }] +] + +[tool.tox.env.lint] +dependency_groups = ["linting"] +commands = [ + # ["ruff", "check", "src/"], + ["mypy", "src/"], +] + +[tool.tox.env.docs] +dependency_groups = ["docs"] +allowlist_externals = ["make"] +changedir = "{toxinidir}/docs" +commands = [ + ["make", "clean"], + ["make", "html"], +] + +[tool.tox.env.packaging] +base_python = ["python3.10"] +dependency_groups = ["packaging"] +allowlist_externals = ["rm"] +commands = [ + ["rm", "-rf", "dist/"], + ["check-manifest"], + ["python", "-m", "build", "--outdir", "dist/"], + ["twine", "check", "dist/*"], +] + +[tool.tox.env.publish] +base_python = ["python3.10"] +dependency_groups = ["packaging"] +commands = [ + ["python", "-m", "build", "--outdir", "dist/"], + ["twine", "check", "dist/*"], + ["twine", "upload", "dist/*"], +] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index b0fef3c1..00000000 --- a/setup.cfg +++ /dev/null @@ -1,41 +0,0 @@ -[tool:pytest] -testpaths = test - -[coverage:run] -branch = True -source = wsproto - -[coverage:report] -show_missing = True -exclude_lines = - pragma: no cover - raise NotImplementedError() - -[coverage:paths] -source = - src - .tox/*/site-packages - -[flake8] -max-line-length = 120 -max-complexity = 15 -ignore = E203,W503,W504 - -[isort] -combine_as_imports=True -force_grid_wrap=0 -include_trailing_comma=True -known_first_party=wsproto, test -known_third_party=h11, pytest -line_length=88 -multi_line_output=3 -no_lines_before=LOCALFOLDER -order_by_type=False - -[mypy] -strict = true -warn_unused_configs = true -show_error_codes = true - -[mypy-h11.*] -ignore_missing_imports = True diff --git a/setup.py b/setup.py deleted file mode 100644 index 1a774f7e..00000000 --- a/setup.py +++ /dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 - -import os -import re - -from setuptools import setup, find_packages - -PROJECT_ROOT = os.path.dirname(__file__) - -with open(os.path.join(PROJECT_ROOT, 'README.rst')) as file_: - long_description = file_.read() - -version_regex = r'__version__ = ["\']([^"\']*)["\']' -with open(os.path.join(PROJECT_ROOT, 'src/wsproto/__init__.py')) as file_: - text = file_.read() - match = re.search(version_regex, text) - if match: - version = match.group(1) - else: - raise RuntimeError("No version number found!") - -setup( - name='wsproto', - version=version, - description='WebSockets state-machine based protocol implementation', - long_description=long_description, - long_description_content_type='text/x-rst', - author='Benno Rice', - author_email='benno@jeamland.net', - url='https://github.com/python-hyper/wsproto/', - packages=find_packages(where="src"), - package_data={'wsproto': ['py.typed']}, - package_dir={'': 'src'}, - python_requires='>=3.7.0', - license='MIT License', - classifiers=[ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', - ], - install_requires=[ - 'h11>=0.16.0,<1', - ], -) diff --git a/src/wsproto/__init__.py b/src/wsproto/__init__.py index 60d04f2d..dbc1e23f 100644 --- a/src/wsproto/__init__.py +++ b/src/wsproto/__init__.py @@ -4,13 +4,18 @@ A WebSocket implementation. """ +from __future__ import annotations -from typing import Generator, Optional, Union +from typing import TYPE_CHECKING from .connection import Connection, ConnectionState, ConnectionType -from .events import Event from .handshake import H11Handshake -from .typing import Headers + +if TYPE_CHECKING: + from collections.abc import Generator + + from .events import Event + from .typing import Headers __version__ = "1.2.0+dev" @@ -29,7 +34,7 @@ def __init__(self, connection_type: ConnectionType) -> None: """ self.client = connection_type is ConnectionType.CLIENT self.handshake = H11Handshake(connection_type) - self.connection: Optional[Connection] = None + self.connection: Connection | None = None @property def state(self) -> ConnectionState: @@ -42,7 +47,7 @@ def state(self) -> ConnectionState: return self.connection.state def initiate_upgrade_connection( - self, headers: Headers, path: Union[bytes, str] + self, headers: Headers, path: bytes | str, ) -> None: self.handshake.initiate_upgrade_connection(headers, path) @@ -65,7 +70,7 @@ def send(self, event: Event) -> bytes: data += self.connection.send(event) return data - def receive_data(self, data: Optional[bytes]) -> None: + def receive_data(self, data: bytes | None) -> None: """ Feed network data into the connection instance. diff --git a/src/wsproto/connection.py b/src/wsproto/connection.py index 5b07b396..84751ab2 100644 --- a/src/wsproto/connection.py +++ b/src/wsproto/connection.py @@ -4,10 +4,11 @@ An implementation of a WebSocket connection. """ +from __future__ import annotations from collections import deque from enum import Enum -from typing import Deque, Generator, List, Optional +from typing import TYPE_CHECKING from .events import ( BytesMessage, @@ -18,10 +19,14 @@ Pong, TextMessage, ) -from .extensions import Extension from .frame_protocol import CloseReason, FrameProtocol, Opcode, ParseFailed from .utilities import LocalProtocolError +if TYPE_CHECKING: + from collections.abc import Generator + + from .extensions import Extension + class ConnectionState(Enum): """ @@ -68,7 +73,7 @@ class Connection: def __init__( self, connection_type: ConnectionType, - extensions: Optional[List[Extension]] = None, + extensions: list[Extension] | None = None, trailing_data: bytes = b"", ) -> None: """ @@ -82,7 +87,7 @@ def __init__( processed. """ self.client = connection_type is ConnectionType.CLIENT - self._events: Deque[Event] = deque() + self._events: deque[Event] = deque() self._proto = FrameProtocol(self.client, extensions or []) self._state = ConnectionState.OPEN self.receive_data(trailing_data) @@ -109,12 +114,13 @@ def send(self, event: Event) -> bytes: else: self._state = ConnectionState.LOCAL_CLOSING else: + msg = f"Event {event} cannot be sent in state {self.state}." raise LocalProtocolError( - f"Event {event} cannot be sent in state {self.state}." + msg, ) return data - def receive_data(self, data: Optional[bytes]) -> None: + def receive_data(self, data: bytes | None) -> None: """ Pass some received data to the connection for handling. @@ -124,7 +130,6 @@ def receive_data(self, data: Optional[bytes]) -> None: :param data: The data received from the remote peer on the network. :type data: ``bytes`` """ - if data is None: # "If _The WebSocket Connection is Closed_ and no Close control # frame was received by the endpoint (such as could occur if the @@ -137,7 +142,8 @@ def receive_data(self, data: Optional[bytes]) -> None: if self.state in (ConnectionState.OPEN, ConnectionState.LOCAL_CLOSING): self._proto.receive_bytes(data) elif self.state is ConnectionState.CLOSED: - raise LocalProtocolError("Connection already closed.") + msg = "Connection already closed." + raise LocalProtocolError(msg) else: pass # pragma: no cover @@ -154,12 +160,14 @@ def events(self) -> Generator[Event, None, None]: try: for frame in self._proto.received_frames(): if frame.opcode is Opcode.PING: - assert frame.frame_finished and frame.message_finished + assert frame.frame_finished + assert frame.message_finished assert isinstance(frame.payload, (bytes, bytearray)) yield Ping(payload=frame.payload) elif frame.opcode is Opcode.PONG: - assert frame.frame_finished and frame.message_finished + assert frame.frame_finished + assert frame.message_finished assert isinstance(frame.payload, (bytes, bytearray)) yield Pong(payload=frame.payload) @@ -183,7 +191,7 @@ def events(self) -> Generator[Event, None, None]: elif frame.opcode is Opcode.BINARY: assert isinstance(frame.payload, (bytes, bytearray)) yield BytesMessage( - data=frame.payload, + data=bytearray(frame.payload), frame_finished=frame.frame_finished, message_finished=frame.message_finished, ) diff --git a/src/wsproto/events.py b/src/wsproto/events.py index 3199f700..4faf6e4c 100644 --- a/src/wsproto/events.py +++ b/src/wsproto/events.py @@ -4,13 +4,17 @@ Events that result from processing data on a WebSocket connection. """ +from __future__ import annotations from abc import ABC from dataclasses import dataclass, field -from typing import Generic, List, Optional, Sequence, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypeVar -from .extensions import Extension -from .typing import Headers +if TYPE_CHECKING: + from collections.abc import Sequence + + from .extensions import Extension + from .typing import Headers class Event(ABC): @@ -23,7 +27,8 @@ class Event(ABC): @dataclass(frozen=True) class Request(Event): - """The beginning of a Websocket connection, the HTTP Upgrade request + """ + The beginning of a Websocket connection, the HTTP Upgrade request This event is fired when a SERVER connection receives a WebSocket handshake request (HTTP with upgrade header). @@ -55,14 +60,15 @@ class Request(Event): host: str target: str - extensions: Union[Sequence[Extension], Sequence[str]] = field(default_factory=list) + extensions: Sequence[Extension] | Sequence[str] = field(default_factory=list) extra_headers: Headers = field(default_factory=list) - subprotocols: List[str] = field(default_factory=list) + subprotocols: list[str] = field(default_factory=list) @dataclass(frozen=True) class AcceptConnection(Event): - """The acceptance of a Websocket upgrade request. + """ + The acceptance of a Websocket upgrade request. This event is fired when a CLIENT receives an acceptance response from a server. It is also used to accept an upgrade request when @@ -81,14 +87,15 @@ class AcceptConnection(Event): """ - subprotocol: Optional[str] = None - extensions: List[Extension] = field(default_factory=list) + subprotocol: str | None = None + extensions: list[Extension] = field(default_factory=list) extra_headers: Headers = field(default_factory=list) @dataclass(frozen=True) class RejectConnection(Event): - """The rejection of a Websocket upgrade request, the HTTP response. + """ + The rejection of a Websocket upgrade request, the HTTP response. The ``RejectConnection`` event sends the appropriate HTTP headers to communicate to the peer that the handshake has been rejected. You may also @@ -132,7 +139,8 @@ class RejectConnection(Event): @dataclass(frozen=True) class RejectData(Event): - """The rejection HTTP response body. + """ + The rejection HTTP response body. The caller may send multiple ``RejectData`` events. The final event should have the ``body_finished`` attribute set to ``True``. @@ -155,7 +163,8 @@ class RejectData(Event): @dataclass(frozen=True) class CloseConnection(Event): - """The end of a Websocket connection, represents a closure frame. + """ + The end of a Websocket connection, represents a closure frame. **wsproto does not automatically send a response to a close event.** To comply with the RFC you MUST send a close event back to the remote WebSocket @@ -177,19 +186,20 @@ class CloseConnection(Event): """ code: int - reason: Optional[str] = None + reason: str | None = None - def response(self) -> "CloseConnection": + def response(self) -> CloseConnection: """Generate an RFC-compliant close frame to send back to the peer.""" return CloseConnection(code=self.code, reason=self.reason) -T = TypeVar("T", bytes, str) +T = TypeVar("T", bytes, bytearray, str) @dataclass(frozen=True) class Message(Event, Generic[T]): - """The websocket data message. + """ + The websocket data message. Fields: @@ -220,7 +230,8 @@ class Message(Event, Generic[T]): @dataclass(frozen=True) class TextMessage(Message[str]): # pylint: disable=unsubscriptable-object - """This event is fired when a data frame with TEXT payload is received. + """ + Fired when a data frame with TEXT payload is received. Fields: @@ -229,35 +240,33 @@ class TextMessage(Message[str]): # pylint: disable=unsubscriptable-object The message data as string, This only represents a single chunk of data and not a full WebSocket message. You need to buffer and reassemble these chunks to get the full message. - """ - # https://github.com/python/mypy/issues/5744 data: str @dataclass(frozen=True) -class BytesMessage(Message[bytes]): # pylint: disable=unsubscriptable-object - """This event is fired when a data frame with BINARY payload is - received. +class BytesMessage(Message[bytearray]): # pylint: disable=unsubscriptable-object + """ + Fired when a data frame with BINARY payload is received. Fields: .. attribute:: data - The message data as byte string, can be decoded as UTF-8 for + The message data as bytearray, can be decoded as UTF-8 for TEXT messages. This only represents a single chunk of data and not a full WebSocket message. You need to buffer and reassemble these chunks to get the full message. """ - # https://github.com/python/mypy/issues/5744 - data: bytes + data: bytearray @dataclass(frozen=True) class Ping(Event): - """The Ping event can be sent to trigger a ping frame and is fired + """ + The Ping event can be sent to trigger a ping frame and is fired when a Ping is received. **wsproto does not automatically send a pong response to a ping event.** To @@ -273,14 +282,15 @@ class Ping(Event): payload: bytes = b"" - def response(self) -> "Pong": + def response(self) -> Pong: """Generate an RFC-compliant :class:`Pong` response to this ping.""" return Pong(payload=self.payload) @dataclass(frozen=True) class Pong(Event): - """The Pong event is fired when a Pong is received. + """ + The Pong event is fired when a Pong is received. Fields: diff --git a/src/wsproto/extensions.py b/src/wsproto/extensions.py index 4f6f4ee3..b8a93f00 100644 --- a/src/wsproto/extensions.py +++ b/src/wsproto/extensions.py @@ -4,10 +4,11 @@ WebSocket extensions. """ +from __future__ import annotations import zlib from abc import ABC, abstractmethod -from typing import Optional, Tuple, Union +from typing import Optional from .frame_protocol import CloseReason, FrameDecoder, FrameProtocol, Opcode, RsvBits @@ -19,10 +20,10 @@ def enabled(self) -> bool: return False @abstractmethod - def offer(self) -> Union[bool, str]: + def offer(self) -> bool | str: pass - def accept(self, offer: str) -> Optional[Union[bool, str]]: + def accept(self, offer: str) -> bool | str | None: pass def finalize(self, offer: str) -> None: @@ -30,31 +31,31 @@ def finalize(self, offer: str) -> None: def frame_inbound_header( self, - proto: Union[FrameDecoder, FrameProtocol], + proto: FrameDecoder | FrameProtocol, opcode: Opcode, rsv: RsvBits, payload_length: int, - ) -> Union[CloseReason, RsvBits]: + ) -> CloseReason | RsvBits: return RsvBits(False, False, False) def frame_inbound_payload_data( - self, proto: Union[FrameDecoder, FrameProtocol], data: bytes - ) -> Union[bytes, CloseReason]: + self, proto: FrameDecoder | FrameProtocol, data: bytes, + ) -> bytes | CloseReason: return data def frame_inbound_complete( - self, proto: Union[FrameDecoder, FrameProtocol], fin: bool - ) -> Union[bytes, CloseReason, None]: + self, proto: FrameDecoder | FrameProtocol, fin: bool, + ) -> bytes | CloseReason | None: pass def frame_outbound( self, - proto: Union[FrameDecoder, FrameProtocol], + proto: FrameDecoder | FrameProtocol, opcode: Opcode, rsv: RsvBits, data: bytes, fin: bool, - ) -> Tuple[RsvBits, bytes]: + ) -> tuple[RsvBits, bytes]: return (rsv, data) @@ -67,9 +68,9 @@ class PerMessageDeflate(Extension): def __init__( self, client_no_context_takeover: bool = False, - client_max_window_bits: Optional[int] = None, + client_max_window_bits: int | None = None, server_no_context_takeover: bool = False, - server_max_window_bits: Optional[int] = None, + server_max_window_bits: int | None = None, ) -> None: self.client_no_context_takeover = client_no_context_takeover self.server_no_context_takeover = server_no_context_takeover @@ -83,11 +84,11 @@ def __init__( self._compressor: Optional[zlib._Compress] = None # noqa self._decompressor: Optional[zlib._Decompress] = None # noqa # This refers to the current frame - self._inbound_is_compressible: Optional[bool] = None + self._inbound_is_compressible: bool | None = None # This refers to the ongoing message (which might span multiple # frames). Only the first frame in a fragmented message is flagged for # compression, so this carries that bit forward. - self._inbound_compressed: Optional[bool] = None + self._inbound_compressed: bool | None = None self._enabled = False @@ -98,7 +99,8 @@ def client_max_window_bits(self) -> int: @client_max_window_bits.setter def client_max_window_bits(self, value: int) -> None: if value < 9 or value > 15: - raise ValueError("Window size must be between 9 and 15 inclusive") + msg = "Window size must be between 9 and 15 inclusive" + raise ValueError(msg) self._client_max_window_bits = value @property @@ -108,7 +110,8 @@ def server_max_window_bits(self) -> int: @server_max_window_bits.setter def server_max_window_bits(self, value: int) -> None: if value < 9 or value > 15: - raise ValueError("Window size must be between 9 and 15 inclusive") + msg = "Window size must be between 9 and 15 inclusive" + raise ValueError(msg) self._server_max_window_bits = value def _compressible_opcode(self, opcode: Opcode) -> bool: @@ -117,10 +120,10 @@ def _compressible_opcode(self, opcode: Opcode) -> bool: def enabled(self) -> bool: return self._enabled - def offer(self) -> Union[bool, str]: + def offer(self) -> bool | str: parameters = [ - "client_max_window_bits=%d" % self.client_max_window_bits, - "server_max_window_bits=%d" % self.server_max_window_bits, + f"client_max_window_bits={self.client_max_window_bits}", + f"server_max_window_bits={self.server_max_window_bits}", ] if self.client_no_context_takeover: @@ -144,7 +147,7 @@ def finalize(self, offer: str) -> None: self._enabled = True - def _parse_params(self, params: str) -> Tuple[Optional[int], Optional[int]]: + def _parse_params(self, params: str) -> tuple[int | None, int | None]: client_max_window_bits = None server_max_window_bits = None @@ -167,7 +170,7 @@ def _parse_params(self, params: str) -> Tuple[Optional[int], Optional[int]]: return client_max_window_bits, server_max_window_bits - def accept(self, offer: str) -> Union[bool, None, str]: + def accept(self, offer: str) -> bool | None | str: client_max_window_bits, server_max_window_bits = self._parse_params(offer) parameters = [] @@ -178,10 +181,10 @@ def accept(self, offer: str) -> Union[bool, None, str]: parameters.append("server_no_context_takeover") try: if client_max_window_bits is not None: - parameters.append("client_max_window_bits=%d" % client_max_window_bits) + parameters.append(f"client_max_window_bits={client_max_window_bits}") self.client_max_window_bits = client_max_window_bits if server_max_window_bits is not None: - parameters.append("server_max_window_bits=%d" % server_max_window_bits) + parameters.append(f"server_max_window_bits={server_max_window_bits}") self.server_max_window_bits = server_max_window_bits except ValueError: return None @@ -191,11 +194,11 @@ def accept(self, offer: str) -> Union[bool, None, str]: def frame_inbound_header( self, - proto: Union[FrameDecoder, FrameProtocol], + proto: FrameDecoder | FrameProtocol, opcode: Opcode, rsv: RsvBits, payload_length: int, - ) -> Union[CloseReason, RsvBits]: + ) -> CloseReason | RsvBits: if rsv.rsv1 and opcode.iscontrol(): return CloseReason.PROTOCOL_ERROR if rsv.rsv1 and opcode is Opcode.CONTINUATION: @@ -217,8 +220,8 @@ def frame_inbound_header( return RsvBits(True, False, False) def frame_inbound_payload_data( - self, proto: Union[FrameDecoder, FrameProtocol], data: bytes - ) -> Union[bytes, CloseReason]: + self, proto: FrameDecoder | FrameProtocol, data: bytes, + ) -> bytes | CloseReason: if not self._inbound_compressed or not self._inbound_is_compressible: return data assert self._decompressor is not None @@ -229,8 +232,8 @@ def frame_inbound_payload_data( return CloseReason.INVALID_FRAME_PAYLOAD_DATA def frame_inbound_complete( - self, proto: Union[FrameDecoder, FrameProtocol], fin: bool - ) -> Union[bytes, CloseReason, None]: + self, proto: FrameDecoder | FrameProtocol, fin: bool, + ) -> bytes | CloseReason | None: if not fin: return None if not self._inbound_is_compressible: @@ -261,17 +264,17 @@ def frame_inbound_complete( def frame_outbound( self, - proto: Union[FrameDecoder, FrameProtocol], + proto: FrameDecoder | FrameProtocol, opcode: Opcode, rsv: RsvBits, data: bytes, fin: bool, - ) -> Tuple[RsvBits, bytes]: + ) -> tuple[RsvBits, bytes]: if not self._compressible_opcode(opcode): return (rsv, data) if opcode is not Opcode.CONTINUATION: - rsv = RsvBits(True, *rsv[1:]) + rsv = RsvBits(True, rsv[1], rsv[2]) if self._compressor is None: assert opcode is not Opcode.CONTINUATION @@ -280,7 +283,7 @@ def frame_outbound( else: bits = self.server_max_window_bits self._compressor = zlib.compressobj( - zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -int(bits) + zlib.Z_DEFAULT_COMPRESSION, zlib.DEFLATED, -int(bits), ) data = self._compressor.compress(bytes(data)) @@ -300,10 +303,10 @@ def frame_outbound( return (rsv, data) def __repr__(self) -> str: - descr = ["client_max_window_bits=%d" % self.client_max_window_bits] + descr = [f"client_max_window_bits={self.client_max_window_bits}"] if self.client_no_context_takeover: descr.append("client_no_context_takeover") - descr.append("server_max_window_bits=%d" % self.server_max_window_bits) + descr.append(f"server_max_window_bits={self.server_max_window_bits}") if self.server_no_context_takeover: descr.append("server_no_context_takeover") diff --git a/src/wsproto/frame_protocol.py b/src/wsproto/frame_protocol.py index d13a769e..288276d5 100644 --- a/src/wsproto/frame_protocol.py +++ b/src/wsproto/frame_protocol.py @@ -4,14 +4,18 @@ WebSocket frame protocol implementation. """ +from __future__ import annotations +import contextlib import os import struct -from codecs import getincrementaldecoder, IncrementalDecoder +from codecs import IncrementalDecoder, getincrementaldecoder from enum import IntEnum -from typing import Generator, List, NamedTuple, Optional, Tuple, TYPE_CHECKING, Union +from typing import TYPE_CHECKING, NamedTuple if TYPE_CHECKING: + from collections.abc import Generator + from .extensions import Extension # pragma: no cover @@ -19,12 +23,13 @@ class XorMaskerSimple: - def __init__(self, masking_key: bytes) -> None: + def __init__(self, masking_key: bytearray | bytes) -> None: self._masking_key = masking_key - def process(self, data: bytes) -> bytes: + def process(self, data: bytearray) -> bytearray: + data = bytearray(data) if data: - data_array = bytearray(data) + data_array = data a, b, c, d = (_XOR_TABLE[n] for n in self._masking_key) data_array[::4] = data_array[::4].translate(a) data_array[1::4] = data_array[1::4].translate(b) @@ -38,12 +43,12 @@ def process(self, data: bytes) -> bytes: self._masking_key[key_rotation:] + self._masking_key[:key_rotation] ) - return bytes(data_array) + return data_array return data class XorMaskerNull: - def process(self, data: bytes) -> bytes: + def process(self, data: bytearray) -> bytearray: return data @@ -207,7 +212,7 @@ class CloseReason(IntEnum): class ParseFailed(Exception): def __init__( - self, msg: str, code: CloseReason = CloseReason.PROTOCOL_ERROR + self, msg: str, code: CloseReason = CloseReason.PROTOCOL_ERROR, ) -> None: super().__init__(msg) self.code = code @@ -224,12 +229,12 @@ class Header(NamedTuple): rsv: RsvBits opcode: Opcode payload_len: int - masking_key: Optional[bytes] + masking_key: bytes | None class Frame(NamedTuple): opcode: Opcode - payload: Union[bytes, str, Tuple[int, str]] + payload: bytes | str | tuple[int, str] frame_finished: bool message_finished: bool @@ -246,12 +251,11 @@ def _truncate_utf8(data: bytes, nbytes: int) -> bytes: # whole message twice when in theory we could just peek at the last # few characters, but since this is only used for close messages (max # length = 125 bytes) it really doesn't matter. - data = data.decode("utf-8", errors="ignore").encode("utf-8") - return data + return data.decode("utf-8", errors="ignore").encode("utf-8") class Buffer: - def __init__(self, initial_bytes: Optional[bytes] = None) -> None: + def __init__(self, initial_bytes: bytes | None = None) -> None: self.buffer = bytearray() self.bytes_used = 0 if initial_bytes: @@ -260,7 +264,7 @@ def __init__(self, initial_bytes: Optional[bytes] = None) -> None: def feed(self, new_bytes: bytes) -> None: self.buffer += new_bytes - def consume_at_most(self, nbytes: int) -> bytes: + def consume_at_most(self, nbytes: int) -> bytearray: if not nbytes: return bytearray() @@ -268,7 +272,7 @@ def consume_at_most(self, nbytes: int) -> bytes: self.bytes_used += len(data) return data - def consume_exactly(self, nbytes: int) -> Optional[bytes]: + def consume_exactly(self, nbytes: int) -> bytearray | None: if len(self.buffer) - self.bytes_used < nbytes: return None @@ -288,18 +292,20 @@ def __len__(self) -> int: class MessageDecoder: def __init__(self) -> None: - self.opcode: Optional[Opcode] = None - self.decoder: Optional[IncrementalDecoder] = None + self.opcode: Opcode | None = None + self.decoder: IncrementalDecoder | None = None def process_frame(self, frame: Frame) -> Frame: assert not frame.opcode.iscontrol() if self.opcode is None: if frame.opcode is Opcode.CONTINUATION: - raise ParseFailed("unexpected CONTINUATION") + msg = "unexpected CONTINUATION" + raise ParseFailed(msg) self.opcode = frame.opcode elif frame.opcode is not Opcode.CONTINUATION: - raise ParseFailed("expected CONTINUATION, got %r" % frame.opcode) + msg = f"expected CONTINUATION, got {frame.opcode!r}" + raise ParseFailed(msg) if frame.opcode is Opcode.TEXT: self.decoder = getincrementaldecoder("utf-8")() @@ -326,26 +332,25 @@ def process_frame(self, frame: Frame) -> Frame: class FrameDecoder: def __init__( - self, client: bool, extensions: Optional[List["Extension"]] = None + self, client: bool, extensions: list[Extension] | None = None, ) -> None: self.client = client self.extensions = extensions or [] self.buffer = Buffer() - self.header: Optional[Header] = None - self.effective_opcode: Optional[Opcode] = None - self.masker: Union[None, XorMaskerNull, XorMaskerSimple] = None + self.header: Header | None = None + self.effective_opcode: Opcode | None = None + self.masker: None | XorMaskerNull | XorMaskerSimple = None self.payload_required = 0 self.payload_consumed = 0 def receive_bytes(self, data: bytes) -> None: self.buffer.feed(data) - def process_buffer(self) -> Optional[Frame]: - if not self.header: - if not self.parse_header(): - return None + def process_buffer(self) -> Frame | None: + if not self.header and not self.parse_header(): + return None # parse_header() sets these. assert self.header is not None assert self.masker is not None @@ -366,22 +371,24 @@ def process_buffer(self) -> Optional[Frame]: payload = self.masker.process(payload) for extension in self.extensions: - payload_ = extension.frame_inbound_payload_data(self, payload) + payload_ = extension.frame_inbound_payload_data(self, bytes(payload)) if isinstance(payload_, CloseReason): - raise ParseFailed("error in extension", payload_) - payload = payload_ + msg = "error in extension" + raise ParseFailed(msg, payload_) + payload = bytearray(payload_) if finished: final = bytearray() for extension in self.extensions: result = extension.frame_inbound_complete(self, self.header.fin) if isinstance(result, CloseReason): - raise ParseFailed("error in extension", result) + msg = "error in extension" + raise ParseFailed(msg, result) if result is not None: final += result payload += final - frame = Frame(self.effective_opcode, payload, finished, self.header.fin) + frame = Frame(self.effective_opcode, bytes(payload), finished, self.header.fin) if finished: self.header = None @@ -408,10 +415,12 @@ def parse_header(self) -> bool: try: opcode = Opcode(opcode) except ValueError: - raise ParseFailed(f"Invalid opcode {opcode:#x}") + msg = f"Invalid opcode {opcode:#x}" + raise ParseFailed(msg) if opcode.iscontrol() and not fin: - raise ParseFailed("Invalid attempt to fragment control frame") + msg = "Invalid attempt to fragment control frame" + raise ParseFailed(msg) has_mask = bool(data[1] & MASK_MASK) payload_len_short = data[1] & PAYLOAD_LEN_MASK @@ -423,9 +432,11 @@ def parse_header(self) -> bool: self.extension_processing(opcode, rsv, payload_len) if has_mask and self.client: - raise ParseFailed("client received unexpected masked frame") + msg = "client received unexpected masked frame" + raise ParseFailed(msg) if not has_mask and not self.client: - raise ParseFailed("server received unexpected unmasked frame") + msg = "server received unexpected unmasked frame" + raise ParseFailed(msg) if has_mask: masking_key = self.buffer.consume_exactly(4) if masking_key is None: @@ -446,18 +457,20 @@ def parse_header(self) -> bool: return True def parse_extended_payload_length( - self, opcode: Opcode, payload_len: int - ) -> Optional[int]: + self, opcode: Opcode, payload_len: int, + ) -> int | None: if opcode.iscontrol() and payload_len > MAX_PAYLOAD_NORMAL: - raise ParseFailed("Control frame with payload len > 125") + msg = "Control frame with payload len > 125" + raise ParseFailed(msg) if payload_len == PAYLOAD_LENGTH_TWO_BYTE: data = self.buffer.consume_exactly(2) if data is None: return None (payload_len,) = struct.unpack("!H", data) if payload_len <= MAX_PAYLOAD_NORMAL: + msg = "Payload length used 2 bytes when 1 would have sufficed" raise ParseFailed( - "Payload length used 2 bytes when 1 would have sufficed" + msg, ) elif payload_len == PAYLOAD_LENGTH_EIGHT_BYTE: data = self.buffer.consume_exactly(8) @@ -465,34 +478,38 @@ def parse_extended_payload_length( return None (payload_len,) = struct.unpack("!Q", data) if payload_len <= MAX_PAYLOAD_TWO_BYTE: + msg = "Payload length used 8 bytes when 2 would have sufficed" raise ParseFailed( - "Payload length used 8 bytes when 2 would have sufficed" + msg, ) if payload_len >> 63: # I'm not sure why this is illegal, but that's what the RFC # says, so... - raise ParseFailed("8-byte payload length with non-zero MSB") + msg = "8-byte payload length with non-zero MSB" + raise ParseFailed(msg) return payload_len def extension_processing( - self, opcode: Opcode, rsv: RsvBits, payload_len: int + self, opcode: Opcode, rsv: RsvBits, payload_len: int, ) -> None: rsv_used = [False, False, False] for extension in self.extensions: result = extension.frame_inbound_header(self, opcode, rsv, payload_len) if isinstance(result, CloseReason): - raise ParseFailed("error in extension", result) + msg = "error in extension" + raise ParseFailed(msg, result) for bit, used in enumerate(result): if used: rsv_used[bit] = True for expected, found in zip(rsv_used, rsv): if found and not expected: - raise ParseFailed("Reserved bit set unexpectedly") + msg = "Reserved bit set unexpectedly" + raise ParseFailed(msg) class FrameProtocol: - def __init__(self, client: bool, extensions: List["Extension"]) -> None: + def __init__(self, client: bool, extensions: list[Extension]) -> None: self.client = client self.extensions = [ext for ext in extensions if ext.enabled()] @@ -501,7 +518,7 @@ def __init__(self, client: bool, extensions: List["Extension"]) -> None: self._message_decoder = MessageDecoder() self._parse_more = self._parse_more_gen() - self._outbound_opcode: Optional[Opcode] = None + self._outbound_opcode: Opcode | None = None def _process_close(self, frame: Frame) -> Frame: data = frame.payload @@ -512,19 +529,21 @@ def _process_close(self, frame: Frame) -> Frame: # WebSocket Connection Close Code_ is considered to be 1005" data = (CloseReason.NO_STATUS_RCVD, "") elif len(data) == 1: - raise ParseFailed("CLOSE with 1 byte payload") + msg = "CLOSE with 1 byte payload" + raise ParseFailed(msg) else: (code,) = struct.unpack("!H", data[:2]) if code < MIN_CLOSE_REASON or code > MAX_CLOSE_REASON: - raise ParseFailed("CLOSE with invalid code") - try: + msg = "CLOSE with invalid code" + raise ParseFailed(msg) + with contextlib.suppress(ValueError): code = CloseReason(code) - except ValueError: - pass if code in LOCAL_ONLY_CLOSE_REASONS: - raise ParseFailed("remote CLOSE with local-only reason") + msg = "remote CLOSE with local-only reason" + raise ParseFailed(msg) if not isinstance(code, CloseReason) and code <= MAX_PROTOCOL_CLOSE_REASON: - raise ParseFailed("CLOSE with unknown reserved code") + msg = "CLOSE with unknown reserved code" + raise ParseFailed(msg) try: reason = data[2:].decode("utf-8") except UnicodeDecodeError as exc: @@ -536,7 +555,7 @@ def _process_close(self, frame: Frame) -> Frame: return Frame(frame.opcode, data, frame.frame_finished, frame.message_finished) - def _parse_more_gen(self) -> Generator[Optional[Frame], None, None]: + def _parse_more_gen(self) -> Generator[Frame | None, None, None]: # Consume as much as we can from self._buffer, yielding events, and # then yield None when we need more data. Or raise ParseFailed. @@ -567,44 +586,47 @@ def received_frames(self) -> Generator[Frame, None, None]: else: yield event - def close(self, code: Optional[int] = None, reason: Optional[str] = None) -> bytes: + def close(self, code: int | None = None, reason: str | None = None) -> bytearray: payload = bytearray() if code is CloseReason.NO_STATUS_RCVD: code = None if code is None and reason: - raise TypeError("cannot specify a reason without a code") + msg = "cannot specify a reason without a code" + raise TypeError(msg) if code in LOCAL_ONLY_CLOSE_REASONS: code = CloseReason.NORMAL_CLOSURE if code is not None: payload += bytearray(struct.pack("!H", code)) if reason is not None: payload += _truncate_utf8( - reason.encode("utf-8"), MAX_PAYLOAD_NORMAL - 2 + reason.encode("utf-8"), MAX_PAYLOAD_NORMAL - 2, ) return self._serialize_frame(Opcode.CLOSE, payload) - def ping(self, payload: bytes = b"") -> bytes: + def ping(self, payload: bytes = b"") -> bytearray: return self._serialize_frame(Opcode.PING, payload) - def pong(self, payload: bytes = b"") -> bytes: + def pong(self, payload: bytes = b"") -> bytearray: return self._serialize_frame(Opcode.PONG, payload) def send_data( - self, payload: Union[bytes, bytearray, str] = b"", fin: bool = True - ) -> bytes: + self, payload: bytes | bytearray | str = b"", fin: bool = True, + ) -> bytearray: if isinstance(payload, (bytes, bytearray, memoryview)): opcode = Opcode.BINARY elif isinstance(payload, str): opcode = Opcode.TEXT payload = payload.encode("utf-8") else: - raise ValueError("Must provide bytes or text") + msg = "Must provide bytes or text" + raise TypeError(msg) if self._outbound_opcode is None: self._outbound_opcode = opcode elif self._outbound_opcode is not opcode: - raise TypeError("Data type mismatch inside message") + msg = "Data type mismatch inside message" + raise TypeError(msg) else: opcode = Opcode.CONTINUATION @@ -621,11 +643,13 @@ def _make_fin_rsv_opcode(self, fin: bool, rsv: RsvBits, opcode: Opcode) -> int: return fin_bits | rsv_bits | opcode_bits def _serialize_frame( - self, opcode: Opcode, payload: bytes = b"", fin: bool = True - ) -> bytes: + self, opcode: Opcode, payload: bytes | bytearray = b"", fin: bool = True, + ) -> bytearray: + payload = bytearray(payload) + rsv = RsvBits(False, False, False) for extension in reversed(self.extensions): - rsv, payload = extension.frame_outbound(self, opcode, rsv, payload, fin) + rsv, payload = extension.frame_outbound(self, opcode, rsv, bytes(payload), fin) fin_rsv_opcode = self._make_fin_rsv_opcode(fin, rsv, opcode) @@ -648,7 +672,8 @@ def _serialize_frame( header = bytearray([fin_rsv_opcode, first_payload]) if second_payload is not None: if opcode.iscontrol(): - raise ValueError("payload too long for control frame") + msg = "payload too long for control frame" + raise ValueError(msg) if quad_payload: header += bytearray(struct.pack("!Q", second_payload)) else: @@ -666,8 +691,8 @@ def _serialize_frame( # authors of malicious applications from selecting the bytes that # appear on the wire." # -- https://tools.ietf.org/html/rfc6455#section-5.3 - masking_key = os.urandom(4) + masking_key = bytearray(os.urandom(4)) masker = XorMaskerSimple(masking_key) - return header + masking_key + masker.process(payload) + return bytearray(header + masking_key + masker.process(bytearray(payload))) - return header + payload + return bytearray(header + payload) diff --git a/src/wsproto/handshake.py b/src/wsproto/handshake.py index 28b0d8c3..ab390b8b 100644 --- a/src/wsproto/handshake.py +++ b/src/wsproto/handshake.py @@ -4,18 +4,12 @@ An implementation of WebSocket handshakes. """ +from __future__ import annotations from collections import deque from typing import ( + TYPE_CHECKING, cast, - Deque, - Dict, - Generator, - Iterable, - List, - Optional, - Sequence, - Union, ) import h11 @@ -23,16 +17,20 @@ from .connection import Connection, ConnectionState, ConnectionType from .events import AcceptConnection, Event, RejectConnection, RejectData, Request from .extensions import Extension -from .typing import Headers from .utilities import ( + LocalProtocolError, + RemoteProtocolError, generate_accept_token, generate_nonce, - LocalProtocolError, normed_header_dict, - RemoteProtocolError, split_comma_header, ) +if TYPE_CHECKING: + from collections.abc import Generator, Iterable, Sequence + + from .typing import Headers + # RFC6455, Section 4.2.1/6 - Reading the Client's Opening Handshake WEBSOCKET_VERSION = b"13" @@ -52,18 +50,19 @@ def __init__(self, connection_type: ConnectionType) -> None: else: self._h11_connection = h11.Connection(h11.SERVER) - self._connection: Optional[Connection] = None - self._events: Deque[Event] = deque() - self._initiating_request: Optional[Request] = None - self._nonce: Optional[bytes] = None + self._connection: Connection | None = None + self._events: deque[Event] = deque() + self._initiating_request: Request | None = None + self._nonce: bytes | None = None @property def state(self) -> ConnectionState: return self._state @property - def connection(self) -> Optional[Connection]: - """Return the established connection. + def connection(self) -> Connection | None: + """ + Return the established connection. This will either return the connection or raise a LocalProtocolError if the connection has not yet been @@ -74,9 +73,10 @@ def connection(self) -> Optional[Connection]: return self._connection def initiate_upgrade_connection( - self, headers: Headers, path: Union[bytes, str] + self, headers: Headers, path: bytes | str, ) -> None: - """Initiate an upgrade connection. + """ + Initiate an upgrade connection. This should be used if the request has already be received and parsed. @@ -85,15 +85,17 @@ def initiate_upgrade_connection( :param str path: A URL path. """ if self.client: + msg = "Cannot initiate an upgrade connection when acting as the client" raise LocalProtocolError( - "Cannot initiate an upgrade connection when acting as the client" + msg, ) upgrade_request = h11.Request(method=b"GET", target=path, headers=headers) h11_client = h11.Connection(h11.CLIENT) self.receive_data(h11_client.send(upgrade_request)) def send(self, event: Event) -> bytes: - """Send an event to the remote. + """ + Send an event to the remote. This will return the bytes to send based on the event or raise a LocalProtocolError if the event is not valid given the @@ -112,13 +114,15 @@ def send(self, event: Event) -> bytes: elif isinstance(event, RejectData): data += self._send_reject_data(event) else: + msg = f"Event {event} cannot be sent during the handshake" raise LocalProtocolError( - f"Event {event} cannot be sent during the handshake" + msg, ) return data - def receive_data(self, data: Optional[bytes]) -> None: - """Receive data from the remote. + def receive_data(self, data: bytes | None) -> None: + """ + Receive data from the remote. A list of events that the remote peer triggered by sending this data can be retrieved with :meth:`events`. @@ -130,8 +134,9 @@ def receive_data(self, data: Optional[bytes]) -> None: try: event = self._h11_connection.next_event() except h11.RemoteProtocolError: + msg = "Bad HTTP message" raise RemoteProtocolError( - "Bad HTTP message", event_hint=RejectConnection() + msg, event_hint=RejectConnection(), ) if ( isinstance(event, h11.ConnectionClosed) @@ -150,7 +155,7 @@ def receive_data(self, data: Optional[bytes]) -> None: headers=list(event.headers), status_code=event.status_code, has_body=False, - ) + ), ) self._state = ConnectionState.CLOSED elif isinstance(event, h11.Response): @@ -160,21 +165,21 @@ def receive_data(self, data: Optional[bytes]) -> None: headers=list(event.headers), status_code=event.status_code, has_body=True, - ) + ), ) elif isinstance(event, h11.Data): self._events.append( - RejectData(data=event.data, body_finished=False) + RejectData(data=event.data, body_finished=False), ) elif isinstance(event, h11.EndOfMessage): self._events.append(RejectData(data=b"", body_finished=True)) self._state = ConnectionState.CLOSED - else: - if isinstance(event, h11.Request): - self._events.append(self._process_connection_request(event)) + elif isinstance(event, h11.Request): + self._events.append(self._process_connection_request(event)) def events(self) -> Generator[Event, None, None]: - """Return a generator that provides any events that have been generated + """ + Return a generator that provides any events that have been generated by protocol activity. :returns: a generator that yields H11 events. @@ -184,18 +189,19 @@ def events(self) -> Generator[Event, None, None]: # Server mode methods - def _process_connection_request( # noqa: MC0001 - self, event: h11.Request + def _process_connection_request( + self, event: h11.Request, ) -> Request: if event.method != b"GET": + msg = "Request method must be GET" raise RemoteProtocolError( - "Request method must be GET", event_hint=RejectConnection() + msg, event_hint=RejectConnection(), ) connection_tokens = None - extensions: List[str] = [] + extensions: list[str] = [] host = None key = None - subprotocols: List[str] = [] + subprotocols: list[str] = [] upgrade = b"" version = None headers: Headers = [] @@ -222,29 +228,34 @@ def _process_connection_request( # noqa: MC0001 if connection_tokens is None or not any( token.lower() == "upgrade" for token in connection_tokens ): + msg = "Missing header, 'Connection: Upgrade'" raise RemoteProtocolError( - "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection() + msg, event_hint=RejectConnection(), ) if version != WEBSOCKET_VERSION: + msg = "Missing header, 'Sec-WebSocket-Version'" raise RemoteProtocolError( - "Missing header, 'Sec-WebSocket-Version'", + msg, event_hint=RejectConnection( headers=[(b"Sec-WebSocket-Version", WEBSOCKET_VERSION)], status_code=426 if version else 400, ), ) if key is None: + msg = "Missing header, 'Sec-WebSocket-Key'" raise RemoteProtocolError( - "Missing header, 'Sec-WebSocket-Key'", event_hint=RejectConnection() + msg, event_hint=RejectConnection(), ) if upgrade.lower() != WEBSOCKET_UPGRADE: + msg = f"Missing header, 'Upgrade: {WEBSOCKET_UPGRADE.decode()}'" raise RemoteProtocolError( - f"Missing header, 'Upgrade: {WEBSOCKET_UPGRADE.decode()}'", + msg, event_hint=RejectConnection(), ) if host is None: + msg = "Missing header, 'Host'" raise RemoteProtocolError( - "Missing header, 'Host'", event_hint=RejectConnection() + msg, event_hint=RejectConnection(), ) self._initiating_request = Request( @@ -272,14 +283,15 @@ def _accept(self, event: AcceptConnection) -> bytes: if event.subprotocol is not None: if event.subprotocol not in self._initiating_request.subprotocols: - raise LocalProtocolError(f"unexpected subprotocol {event.subprotocol}") + msg = f"unexpected subprotocol {event.subprotocol}" + raise LocalProtocolError(msg) headers.append( - (b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii")) + (b"Sec-WebSocket-Protocol", event.subprotocol.encode("ascii")), ) if event.extensions: accepts = server_extensions_handshake( - cast(Sequence[str], self._initiating_request.extensions), + cast("Sequence[str]", self._initiating_request.extensions), event.extensions, ) if accepts: @@ -299,8 +311,9 @@ def _accept(self, event: AcceptConnection) -> bytes: def _reject(self, event: RejectConnection) -> bytes: if self.state != ConnectionState.CONNECTING: + msg = f"Connection cannot be rejected in state {self.state}" raise LocalProtocolError( - "Connection cannot be rejected in state %s" % self.state + msg, ) headers = list(event.headers) @@ -316,8 +329,9 @@ def _reject(self, event: RejectConnection) -> bytes: def _send_reject_data(self, event: RejectData) -> bytes: if self.state != ConnectionState.REJECTING: + msg = f"Cannot send rejection data in state {self.state}" raise LocalProtocolError( - f"Cannot send rejection data in state {self.state}" + msg, ) data = self._h11_connection.send(h11.Data(data=event.data)) or b"" @@ -345,11 +359,11 @@ def _initiate_connection(self, request: Request) -> bytes: ( b"Sec-WebSocket-Protocol", (", ".join(request.subprotocols)).encode("ascii"), - ) + ), ) if request.extensions: - offers: Dict[str, Union[str, bool]] = {} + offers: dict[str, str | bool] = {} for e in request.extensions: assert isinstance(e, Extension) offers[e.name] = e.offer() @@ -372,15 +386,15 @@ def _initiate_connection(self, request: Request) -> bytes: return self._h11_connection.send(upgrade) or b"" def _establish_client_connection( - self, event: h11.InformationalResponse - ) -> AcceptConnection: # noqa: MC0001 + self, event: h11.InformationalResponse, + ) -> AcceptConnection: # _establish_client_connection is always called after _initiate_connection. assert self._initiating_request is not None assert self._nonce is not None accept = None connection_tokens = None - accepts: List[str] = [] + accepts: list[str] = [] subprotocol = None upgrade = b"" headers: Headers = [] @@ -389,16 +403,16 @@ def _establish_client_connection( if name == b"connection": connection_tokens = split_comma_header(value) continue # Skip appending to headers - elif name == b"sec-websocket-extensions": + if name == b"sec-websocket-extensions": accepts = split_comma_header(value) continue # Skip appending to headers - elif name == b"sec-websocket-accept": + if name == b"sec-websocket-accept": accept = value continue # Skip appending to headers - elif name == b"sec-websocket-protocol": + if name == b"sec-websocket-protocol": subprotocol = value.decode("ascii") continue # Skip appending to headers - elif name == b"upgrade": + if name == b"upgrade": upgrade = value continue # Skip appending to headers headers.append((name, value)) @@ -406,25 +420,28 @@ def _establish_client_connection( if connection_tokens is None or not any( token.lower() == "upgrade" for token in connection_tokens ): + msg = "Missing header, 'Connection: Upgrade'" raise RemoteProtocolError( - "Missing header, 'Connection: Upgrade'", event_hint=RejectConnection() + msg, event_hint=RejectConnection(), ) if upgrade.lower() != WEBSOCKET_UPGRADE: + msg = f"Missing header, 'Upgrade: {WEBSOCKET_UPGRADE.decode()}'" raise RemoteProtocolError( - f"Missing header, 'Upgrade: {WEBSOCKET_UPGRADE.decode()}'", + msg, event_hint=RejectConnection(), ) accept_token = generate_accept_token(self._nonce) if accept != accept_token: - raise RemoteProtocolError("Bad accept token", event_hint=RejectConnection()) - if subprotocol is not None: - if subprotocol not in self._initiating_request.subprotocols: - raise RemoteProtocolError( - f"unrecognized subprotocol {subprotocol}", - event_hint=RejectConnection(), - ) + msg = "Bad accept token" + raise RemoteProtocolError(msg, event_hint=RejectConnection()) + if subprotocol is not None and subprotocol not in self._initiating_request.subprotocols: + msg = f"unrecognized subprotocol {subprotocol}" + raise RemoteProtocolError( + msg, + event_hint=RejectConnection(), + ) extensions = client_extensions_handshake( - accepts, cast(Sequence[Extension], self._initiating_request.extensions) + accepts, cast("Sequence[Extension]", self._initiating_request.extensions), ) self._connection = Connection( @@ -434,23 +451,22 @@ def _establish_client_connection( ) self._state = ConnectionState.OPEN return AcceptConnection( - extensions=extensions, extra_headers=headers, subprotocol=subprotocol + extensions=extensions, extra_headers=headers, subprotocol=subprotocol, ) def __repr__(self) -> str: - return "{}(client={}, state={})".format( - self.__class__.__name__, self.client, self.state - ) + return f"{self.__class__.__name__}(client={self.client}, state={self.state})" def server_extensions_handshake( - requested: Iterable[str], supported: List[Extension] -) -> Optional[bytes]: - """Agree on the extensions to use returning an appropriate header value. + requested: Iterable[str], supported: list[Extension], +) -> bytes | None: + """ + Agree on the extensions to use returning an appropriate header value. This returns None if there are no agreed extensions """ - accepts: Dict[str, Union[bool, bytes]] = {} + accepts: dict[str, bool | bytes] = {} for offer in requested: name = offer.split(";", 1)[0].strip() for extension in supported: @@ -463,25 +479,24 @@ def server_extensions_handshake( accepts[extension.name] = accept.encode("ascii") if accepts: - extensions: List[bytes] = [] + extensions: list[bytes] = [] for name, params in accepts.items(): name_bytes = name.encode("ascii") if isinstance(params, bool): assert params extensions.append(name_bytes) + elif params == b"": + extensions.append(b"%s" % (name_bytes)) else: - if params == b"": - extensions.append(b"%s" % (name_bytes)) - else: - extensions.append(b"%s; %s" % (name_bytes, params)) + extensions.append(b"%s; %s" % (name_bytes, params)) return b", ".join(extensions) return None def client_extensions_handshake( - accepted: Iterable[str], supported: Sequence[Extension] -) -> List[Extension]: + accepted: Iterable[str], supported: Sequence[Extension], +) -> list[Extension]: # This raises RemoteProtocolError is the accepted extension is not # supported. extensions = [] @@ -493,7 +508,8 @@ def client_extensions_handshake( extensions.append(extension) break else: + msg = f"unrecognized extension {name}" raise RemoteProtocolError( - f"unrecognized extension {name}", event_hint=RejectConnection() + msg, event_hint=RejectConnection(), ) return extensions diff --git a/src/wsproto/typing.py b/src/wsproto/typing.py index a44b27e5..0063d21f 100644 --- a/src/wsproto/typing.py +++ b/src/wsproto/typing.py @@ -1,3 +1,3 @@ -from typing import List, Tuple +from __future__ import annotations -Headers = List[Tuple[bytes, bytes]] +Headers = list[tuple[bytes, bytes]] diff --git a/src/wsproto/utilities.py b/src/wsproto/utilities.py index 967d94fb..d02117ae 100644 --- a/src/wsproto/utilities.py +++ b/src/wsproto/utilities.py @@ -4,16 +4,18 @@ Utility functions that do not belong in a separate module. """ +from __future__ import annotations import base64 import hashlib import os -from typing import Dict, List, Optional, Union +from typing import TYPE_CHECKING -from h11._headers import Headers as H11Headers +if TYPE_CHECKING: + from h11._headers import Headers as H11Headers -from .events import Event -from .typing import Headers + from .events import Event + from .typing import Headers # RFC6455, Section 1.3 - Opening Handshake ACCEPT_GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" @@ -24,7 +26,8 @@ class ProtocolError(Exception): class LocalProtocolError(ProtocolError): - """Indicates an error due to local/programming errors. + """ + Indicates an error due to local/programming errors. This is raised when the connection is asked to do something that is either incompatible with the state or the websocket standard. @@ -35,7 +38,8 @@ class LocalProtocolError(ProtocolError): class RemoteProtocolError(ProtocolError): - """Indicates an error due to the remote's actions. + """ + Indicates an error due to the remote's actions. This is raised when processing the bytes from the remote if the remote has sent data that is incompatible with the websocket @@ -48,18 +52,18 @@ class RemoteProtocolError(ProtocolError): """ - def __init__(self, message: str, event_hint: Optional[Event] = None) -> None: + def __init__(self, message: str, event_hint: Event | None = None) -> None: self.event_hint = event_hint super().__init__(message) # Some convenience utilities for working with HTTP headers -def normed_header_dict(h11_headers: Union[Headers, H11Headers]) -> Dict[bytes, bytes]: +def normed_header_dict(h11_headers: Headers | H11Headers) -> dict[bytes, bytes]: # This mangles Set-Cookie headers. But it happens that we don't care about # any of those, so it's OK. For every other HTTP header, if there are # multiple instances then you're allowed to join them together with # commas. - name_to_values: Dict[bytes, List[bytes]] = {} + name_to_values: dict[bytes, list[bytes]] = {} for name, value in h11_headers: name_to_values.setdefault(name, []).append(value) name_to_normed_value = {} @@ -73,7 +77,7 @@ def normed_header_dict(h11_headers: Union[Headers, H11Headers]) -> Dict[bytes, b # fine, because the ABNF is just 1#token. But for the extension lists, it's # wrong, because those can contain quoted strings, which can in turn contain # commas. XX FIXME -def split_comma_header(value: bytes) -> List[str]: +def split_comma_header(value: bytes) -> list[str]: return [piece.decode("ascii").strip() for piece in value.split(b",")] diff --git a/test/__init__.py b/tests/__init__.py similarity index 100% rename from test/__init__.py rename to tests/__init__.py diff --git a/test/helpers.py b/tests/helpers.py similarity index 95% rename from test/helpers.py rename to tests/helpers.py index 95dcc8bc..bdc9c213 100644 --- a/test/helpers.py +++ b/tests/helpers.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Optional, Union from wsproto.extensions import Extension diff --git a/test/test_client.py b/tests/test_client.py similarity index 94% rename from test/test_client.py rename to tests/test_client.py index 4de2cc33..d1bbdbce 100644 --- a/test/test_client.py +++ b/tests/test_client.py @@ -1,4 +1,6 @@ -from typing import cast, List, Optional +from __future__ import annotations + +from typing import List, Optional, cast import h11 import pytest @@ -15,11 +17,12 @@ from wsproto.extensions import Extension from wsproto.typing import Headers from wsproto.utilities import ( - generate_accept_token, LocalProtocolError, - normed_header_dict, RemoteProtocolError, + generate_accept_token, + normed_header_dict, ) + from .helpers import FakeExtension @@ -27,7 +30,7 @@ def _make_connection_request(request: Request) -> h11.Request: client = WSConnection(CLIENT) server = h11.Connection(h11.SERVER) server.receive_data(client.send(request)) - return cast(h11.Request, server.next_event()) + return cast("h11.Request", server.next_event()) def test_connection_request() -> None: @@ -50,7 +53,7 @@ def test_connection_request_additional_headers() -> None: host="localhost", target="/", extra_headers=[(b"X-Foo", b"Bar"), (b"X-Bar", b"Foo")], - ) + ), ) headers = normed_header_dict(request.headers) @@ -61,7 +64,7 @@ def test_connection_request_additional_headers() -> None: def test_connection_request_simple_extension() -> None: extension = FakeExtension(offer_response=True) request = _make_connection_request( - Request(host="localhost", target="/", extensions=[extension]) + Request(host="localhost", target="/", extensions=[extension]), ) headers = normed_header_dict(request.headers) @@ -71,7 +74,7 @@ def test_connection_request_simple_extension() -> None: def test_connection_request_simple_extension_no_offer() -> None: extension = FakeExtension(offer_response=False) request = _make_connection_request( - Request(host="localhost", target="/", extensions=[extension]) + Request(host="localhost", target="/", extensions=[extension]), ) headers = normed_header_dict(request.headers) @@ -82,7 +85,7 @@ def test_connection_request_parametrised_extension() -> None: offer_response = "parameter1=value1; parameter2=value2" extension = FakeExtension(offer_response=offer_response) request = _make_connection_request( - Request(host="localhost", target="/", extensions=[extension]) + Request(host="localhost", target="/", extensions=[extension]), ) headers = normed_header_dict(request.headers) @@ -94,7 +97,7 @@ def test_connection_request_parametrised_extension() -> None: def test_connection_request_subprotocols() -> None: request = _make_connection_request( - Request(host="localhost", target="/", subprotocols=["one", "two"]) + Request(host="localhost", target="/", subprotocols=["one", "two"]), ) headers = normed_header_dict(request.headers) @@ -111,10 +114,10 @@ def test_connection_send_state() -> None: Request( host="localhost", target="/", - ) - ) + ), + ), ) - headers = normed_header_dict(cast(h11.Request, server.next_event()).headers) + headers = normed_header_dict(cast("h11.Request", server.next_event()).headers) response = h11.InformationalResponse( status_code=101, headers=[ @@ -155,20 +158,20 @@ def _make_handshake( target="/", subprotocols=subprotocols or [], extensions=extensions or [], - ) - ) + ), + ), ) - request = cast(h11.Request, server.next_event()) + request = cast("h11.Request", server.next_event()) if auto_accept_key: full_request_headers = normed_header_dict(request.headers) response_headers.append( ( b"Sec-WebSocket-Accept", generate_accept_token(full_request_headers[b"sec-websocket-key"]), - ) + ), ) response = h11.InformationalResponse( - status_code=response_status, headers=response_headers + status_code=response_status, headers=response_headers, ) client.receive_data(server.send(response)) assert client.state is not ConnectionState.CONNECTING @@ -178,14 +181,14 @@ def _make_handshake( def test_handshake() -> None: events = _make_handshake( - 101, [(b"connection", b"Upgrade"), (b"upgrade", b"websocket")] + 101, [(b"connection", b"Upgrade"), (b"upgrade", b"websocket")], ) assert events == [AcceptConnection()] def test_broken_handshake() -> None: events = _make_handshake( - 102, [(b"connection", b"Upgrade"), (b"upgrade", b"websocket")] + 102, [(b"connection", b"Upgrade"), (b"upgrade", b"websocket")], ) assert isinstance(events[0], RejectConnection) assert events[0].status_code == 102 @@ -285,7 +288,7 @@ def test_protocol_error() -> None: def _make_handshake_rejection( - status_code: int, body: Optional[bytes] = None + status_code: int, body: Optional[bytes] = None, ) -> List[Event]: client = WSConnection(CLIENT) server = h11.Connection(h11.SERVER) @@ -294,7 +297,7 @@ def _make_handshake_rejection( if body is not None: headers.append(("Content-Length", str(len(body)))) client.receive_data( - server.send(h11.Response(status_code=status_code, headers=headers)) + server.send(h11.Response(status_code=status_code, headers=headers)), ) if body is not None: client.receive_data(server.send(h11.Data(data=body))) @@ -307,7 +310,7 @@ def test_handshake_rejection() -> None: events = _make_handshake_rejection(400) assert events == [ RejectConnection( - headers=[(b"connection", b"close")], has_body=True, status_code=400 + headers=[(b"connection", b"close")], has_body=True, status_code=400, ), RejectData(body_finished=True, data=b""), ] @@ -317,7 +320,7 @@ def test_handshake_rejection_with_body() -> None: events = _make_handshake_rejection(400, b"Hello") assert events == [ RejectConnection( - headers=[(b"content-length", b"5")], has_body=True, status_code=400 + headers=[(b"content-length", b"5")], has_body=True, status_code=400, ), RejectData(body_finished=False, data=b"Hello"), RejectData(body_finished=True, data=b""), diff --git a/test/test_connection.py b/tests/test_connection.py similarity index 97% rename from test/test_connection.py rename to tests/test_connection.py index bbedf550..92921dcc 100644 --- a/test/test_connection.py +++ b/tests/test_connection.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import pytest -from wsproto.connection import CLIENT, Connection, ConnectionState, SERVER +from wsproto.connection import CLIENT, SERVER, Connection, ConnectionState from wsproto.events import ( BytesMessage, CloseConnection, @@ -137,7 +139,7 @@ def test_data(split_message: bool) -> None: data = "ƒñö®∂😎" server.receive_data( - client.send(TextMessage(data=data, message_finished=not split_message)) + client.send(TextMessage(data=data, message_finished=not split_message)), ) event = next(server.events()) assert isinstance(event, TextMessage) diff --git a/test/test_extensions.py b/tests/test_extensions.py similarity index 92% rename from test/test_extensions.py rename to tests/test_extensions.py index 1b5e0c02..aa01a181 100644 --- a/test/test_extensions.py +++ b/tests/test_extensions.py @@ -1,6 +1,9 @@ +from __future__ import annotations + from typing import Union -from wsproto import extensions as wpext, frame_protocol as fp +from wsproto import extensions as wpext +from wsproto import frame_protocol as fp class ConcreteExtension(wpext.Extension): diff --git a/test/test_frame_protocol.py b/tests/test_frame_protocol.py similarity index 99% rename from test/test_frame_protocol.py rename to tests/test_frame_protocol.py index 36b53338..5443348f 100644 --- a/test/test_frame_protocol.py +++ b/tests/test_frame_protocol.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import itertools import struct from binascii import unhexlify @@ -6,7 +8,8 @@ import pytest -from wsproto import extensions as wpext, frame_protocol as fp +from wsproto import extensions as wpext +from wsproto import frame_protocol as fp class TestBuffer: @@ -392,7 +395,7 @@ def _split_message_test( assert frame.message_finished is True def _parse_failure_test( - self, client: bool, frame_bytes: bytes, close_reason: fp.CloseReason + self, client: bool, frame_bytes: bytes, close_reason: fp.CloseReason, ) -> None: decoder = fp.FrameDecoder(client=client) with pytest.raises(fp.ParseFailed) as excinfo: @@ -709,19 +712,19 @@ def frame_inbound_header( return fp.RsvBits(False, False, True) def frame_inbound_payload_data( - self, proto: Union[fp.FrameDecoder, fp.FrameProtocol], data: bytes + self, proto: Union[fp.FrameDecoder, fp.FrameProtocol], data: bytes, ) -> Union[bytes, fp.CloseReason]: self._inbound_payload_data_called = True if data == b"party time": return fp.CloseReason.POLICY_VIOLATION - elif data == b"ragequit": + if data == b"ragequit": self._fail_inbound_complete = True if self._inbound_rsv_bit_set: data = data.decode("utf-8").upper().encode("utf-8") return data def frame_inbound_complete( - self, proto: Union[fp.FrameDecoder, fp.FrameProtocol], fin: bool + self, proto: Union[fp.FrameDecoder, fp.FrameProtocol], fin: bool, ) -> Union[bytes, fp.CloseReason, None]: self._inbound_complete_called = True if self._fail_inbound_complete: @@ -1021,7 +1024,7 @@ def test_reasoned_close(self) -> None: proto = fp.FrameProtocol(client=False, extensions=[]) reason = r"¯\_(ツ)_/¯" expected_payload = struct.pack( - "!H", fp.CloseReason.NORMAL_CLOSURE + "!H", fp.CloseReason.NORMAL_CLOSURE, ) + reason.encode("utf8") data = proto.close(code=fp.CloseReason.NORMAL_CLOSURE, reason=reason) assert data == b"\x88" + bytearray([len(expected_payload)]) + expected_payload @@ -1194,7 +1197,7 @@ def test_data_we_have_no_idea_what_to_do_with(self) -> None: proto = fp.FrameProtocol(client=False, extensions=[]) payload: Dict[str, str] = dict() - with pytest.raises(ValueError): + with pytest.raises(TypeError): # Intentionally passing illegal type. proto.send_data(payload) # type: ignore diff --git a/test/test_handshake.py b/tests/test_handshake.py similarity index 96% rename from test/test_handshake.py rename to tests/test_handshake.py index 9b23bfcc..44973043 100644 --- a/test/test_handshake.py +++ b/tests/test_handshake.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import pytest -from wsproto.connection import CLIENT, ConnectionState, SERVER +from wsproto.connection import CLIENT, SERVER, ConnectionState from wsproto.events import AcceptConnection, Ping, Request from wsproto.handshake import H11Handshake from wsproto.utilities import LocalProtocolError, RemoteProtocolError @@ -44,7 +46,7 @@ def test_rejected_handshake(http: bytes) -> None: b"Connection: Upgrade\r\n" b"Sec-WebSocket-Key: VQr8cvwwZ1fEk62PDq8J3A==\r\n" b"Sec-WebSocket-Version: 13\r\n" - b"\r\n" + b"\r\n", ) @@ -79,7 +81,7 @@ def test_h11_multiple_headers_handshake() -> None: b"Sec-WebSocket-Extensions: this-extension; isnt-seen, even-tho, it-should-be\r\n" b"Sec-WebSocket-Protocol: there-protocols\r\n" b"Sec-WebSocket-Protocol: arent-seen\r\n" - b"Sec-WebSocket-Extensions: this-extension; were-gonna-see, and-another-extension; were-also; gonna-see=100; percent\r\n" # noqa: E501 + b"Sec-WebSocket-Extensions: this-extension; were-gonna-see, and-another-extension; were-also; gonna-see=100; percent\r\n" b"Sec-WebSocket-Protocol: only-these-protocols, are-seen, from-the-request-object\r\n" b"\r\n" ) diff --git a/test/test_permessage_deflate.py b/tests/test_permessage_deflate.py similarity index 96% rename from test/test_permessage_deflate.py rename to tests/test_permessage_deflate.py index 7202d6e2..b52b73dd 100644 --- a/test/test_permessage_deflate.py +++ b/tests/test_permessage_deflate.py @@ -1,12 +1,16 @@ +from __future__ import annotations + import zlib -from typing import cast, Optional, Sequence, TYPE_CHECKING +from collections.abc import Sequence +from typing import TYPE_CHECKING, Optional, cast import pytest -from wsproto import extensions as wpext, frame_protocol as fp +from wsproto import extensions as wpext +from wsproto import frame_protocol as fp if TYPE_CHECKING: - from mypy_extensions import TypedDict + from typing import TypedDict class Params(TypedDict, total=False): client_no_context_takeover: bool @@ -58,14 +62,14 @@ def make_offer_string(self, params: Params) -> str: offer.append("client_max_window_bits") else: offer.append( - "client_max_window_bits=%d" % params["client_max_window_bits"] + "client_max_window_bits=%d" % params["client_max_window_bits"], ) if "server_max_window_bits" in params: if params["server_max_window_bits"] is None: offer.append("server_max_window_bits") else: offer.append( - "server_max_window_bits=%d" % params["server_max_window_bits"] + "server_max_window_bits=%d" % params["server_max_window_bits"], ) if params.get("client_no_context_takeover", False): offer.append("client_no_context_takeover") @@ -75,7 +79,7 @@ def make_offer_string(self, params: Params) -> str: return "; ".join(offer) def compare_params_to_string( - self, params: Params, ext: wpext.PerMessageDeflate, param_string: str + self, params: Params, ext: wpext.PerMessageDeflate, param_string: str, ) -> None: if "client_max_window_bits" in params: if params["client_max_window_bits"] is None: @@ -98,7 +102,7 @@ def compare_params_to_string( def test_offer(self, params: Params) -> None: ext = wpext.PerMessageDeflate(**params) offer = ext.offer() - offer = cast(str, offer) + offer = cast("str", offer) self.compare_params_to_string(params, ext, offer) @@ -121,10 +125,10 @@ def test_finalize(self, params: Params) -> None: if params.get("server_max_window_bits", None): assert ext.server_max_window_bits == params["server_max_window_bits"] assert ext.client_no_context_takeover is params.get( - "client_no_context_takeover", False + "client_no_context_takeover", False, ) assert ext.server_no_context_takeover is params.get( - "server_no_context_takeover", False + "server_no_context_takeover", False, ) assert ext.enabled() @@ -144,7 +148,7 @@ def test_accept(self, params: Params) -> None: offer = self.make_offer_string(params) response = ext.accept(offer) - response = cast(str, response) + response = cast("str", response) if ext.client_no_context_takeover: assert "client_no_context_takeover" in response @@ -181,7 +185,7 @@ def test_inbound_uncompressed_control_frame(self) -> None: proto = fp.FrameProtocol(client=True, extensions=[ext]) result = ext.frame_inbound_header( - proto, fp.Opcode.PING, fp.RsvBits(False, False, False), len(payload) + proto, fp.Opcode.PING, fp.RsvBits(False, False, False), len(payload), ) assert isinstance(result, fp.RsvBits) assert result.rsv1 @@ -199,7 +203,7 @@ def test_inbound_compressed_control_frame(self) -> None: proto = fp.FrameProtocol(client=True, extensions=[ext]) result = ext.frame_inbound_header( - proto, fp.Opcode.PING, fp.RsvBits(True, False, False), len(payload) + proto, fp.Opcode.PING, fp.RsvBits(True, False, False), len(payload), ) assert result == fp.CloseReason.PROTOCOL_ERROR @@ -211,7 +215,7 @@ def test_inbound_compressed_continuation_frame(self) -> None: proto = fp.FrameProtocol(client=True, extensions=[ext]) result = ext.frame_inbound_header( - proto, fp.Opcode.CONTINUATION, fp.RsvBits(True, False, False), len(payload) + proto, fp.Opcode.CONTINUATION, fp.RsvBits(True, False, False), len(payload), ) assert result == fp.CloseReason.PROTOCOL_ERROR @@ -223,7 +227,7 @@ def test_inbound_uncompressed_data_frame(self) -> None: proto = fp.FrameProtocol(client=True, extensions=[ext]) result = ext.frame_inbound_header( - proto, fp.Opcode.BINARY, fp.RsvBits(False, False, False), len(payload) + proto, fp.Opcode.BINARY, fp.RsvBits(False, False, False), len(payload), ) assert isinstance(result, fp.RsvBits) assert result.rsv1 @@ -269,7 +273,7 @@ def test_client_inbound_compressed_multiple_data_frames(self, client: bool) -> N proto = fp.FrameProtocol(client=client, extensions=[ext]) result = ext.frame_inbound_header( - proto, fp.Opcode.BINARY, fp.RsvBits(True, False, False), split + proto, fp.Opcode.BINARY, fp.RsvBits(True, False, False), split, ) assert isinstance(result, fp.RsvBits) assert result.rsv1 @@ -304,7 +308,7 @@ def test_client_decompress_after_uncompressible_frame(self, client: bool) -> Non # A PING frame ext.frame_inbound_header( - proto, fp.Opcode.PING, fp.RsvBits(False, False, False), 0 + proto, fp.Opcode.PING, fp.RsvBits(False, False, False), 0, ) result2 = ext.frame_inbound_payload_data(proto, b"") assert not isinstance(result2, fp.CloseReason) @@ -388,7 +392,7 @@ def test_decompressor_reset(self, client: bool, no_context_takeover: bool) -> No proto = fp.FrameProtocol(client=client, extensions=[ext]) result = ext.frame_inbound_header( - proto, fp.Opcode.BINARY, fp.RsvBits(True, False, False), 0 + proto, fp.Opcode.BINARY, fp.RsvBits(True, False, False), 0, ) assert isinstance(result, fp.RsvBits) assert result.rsv1 @@ -404,7 +408,7 @@ def test_decompressor_reset(self, client: bool, no_context_takeover: bool) -> No assert ext._decompressor is not None result3 = ext.frame_inbound_header( - proto, fp.Opcode.BINARY, fp.RsvBits(True, False, False), 0 + proto, fp.Opcode.BINARY, fp.RsvBits(True, False, False), 0, ) assert isinstance(result3, fp.RsvBits) assert result3.rsv1 @@ -451,13 +455,13 @@ def test_outbound_compress_multiple_frames(self, client: bool) -> None: compressed_payload = b"\xaa\xa8\xc0\n\x00\x00" rsv, data = ext.frame_outbound( - proto, fp.Opcode.BINARY, rsv, payload[:split], False + proto, fp.Opcode.BINARY, rsv, payload[:split], False, ) assert rsv.rsv1 is True rsv = fp.RsvBits(False, False, False) rsv, more_data = ext.frame_outbound( - proto, fp.Opcode.CONTINUATION, rsv, payload[split:], True + proto, fp.Opcode.CONTINUATION, rsv, payload[split:], True, ) assert rsv.rsv1 is False assert data + more_data == compressed_payload diff --git a/test/test_server.py b/tests/test_server.py similarity index 95% rename from test/test_server.py rename to tests/test_server.py index 47627b48..10e3e7f8 100644 --- a/test/test_server.py +++ b/tests/test_server.py @@ -1,4 +1,6 @@ -from typing import cast, List, Optional, Tuple +from __future__ import annotations + +from typing import List, Optional, Tuple, cast import h11 import pytest @@ -9,11 +11,12 @@ from wsproto.extensions import Extension from wsproto.typing import Headers from wsproto.utilities import ( + RemoteProtocolError, generate_accept_token, generate_nonce, normed_header_dict, - RemoteProtocolError, ) + from .helpers import FakeExtension @@ -21,7 +24,7 @@ def _make_connection_request(request_headers: Headers, method: str = "GET") -> R client = h11.Connection(h11.CLIENT) server = WSConnection(SERVER) server.receive_data( - client.send(h11.Request(method=method, target="/", headers=request_headers)) + client.send(h11.Request(method=method, target="/", headers=request_headers)), ) event = next(server.events()) assert isinstance(event, Request) @@ -37,7 +40,7 @@ def test_connection_request() -> None: (b"Sec-WebSocket-Version", b"13"), (b"Sec-WebSocket-Key", generate_nonce()), (b"X-Foo", b"bar"), - ] + ], ) assert event.extensions == [] @@ -78,7 +81,7 @@ def test_connection_request_bad_connection_header() -> None: (b"Upgrade", b"websocket"), (b"Sec-WebSocket-Version", b"13"), (b"Sec-WebSocket-Key", generate_nonce()), - ] + ], ) assert str(excinfo.value) == "Missing header, 'Connection: Upgrade'" @@ -92,7 +95,7 @@ def test_connection_request_bad_upgrade_header() -> None: (b"Upgrade", b"h2c"), (b"Sec-WebSocket-Version", b"13"), (b"Sec-WebSocket-Key", generate_nonce()), - ] + ], ) assert str(excinfo.value) == "Missing header, 'Upgrade: websocket'" @@ -107,11 +110,11 @@ def test_connection_request_bad_version_header(version: bytes) -> None: (b"Upgrade", b"websocket"), (b"Sec-WebSocket-Version", version), (b"Sec-WebSocket-Key", generate_nonce()), - ] + ], ) assert str(excinfo.value) == "Missing header, 'Sec-WebSocket-Version'" assert excinfo.value.event_hint == RejectConnection( - headers=[(b"Sec-WebSocket-Version", b"13")], status_code=426 + headers=[(b"Sec-WebSocket-Version", b"13")], status_code=426, ) @@ -123,7 +126,7 @@ def test_connection_request_key_header() -> None: (b"Connection", b"Keep-Alive, Upgrade"), (b"Upgrade", b"websocket"), (b"Sec-WebSocket-Version", b"13"), - ] + ], ) assert str(excinfo.value) == "Missing header, 'Sec-WebSocket-Key'" @@ -142,7 +145,7 @@ def test_upgrade_request() -> None: "/", ) event = next(server.events()) - event = cast(Request, event) + event = cast("Request", event) assert event.extensions == [] assert event.host == "localhost" @@ -180,8 +183,8 @@ def _make_handshake( (b"Sec-WebSocket-Key", nonce), ] + request_headers, - ) - ) + ), + ), ) client.receive_data( server.send( @@ -189,11 +192,11 @@ def _make_handshake( extra_headers=accept_headers or [], subprotocol=subprotocol, extensions=extensions or [], - ) - ) + ), + ), ) event = client.next_event() - return cast(h11.InformationalResponse, event), nonce + return cast("h11.InformationalResponse", event), nonce def test_handshake() -> None: @@ -222,7 +225,7 @@ def test_handshake_extra_headers() -> None: @pytest.mark.parametrize("accept_subprotocol", ["one", "two"]) def test_handshake_with_subprotocol(accept_subprotocol: str) -> None: response, _ = _make_handshake( - [(b"Sec-Websocket-Protocol", b"one, two")], subprotocol=accept_subprotocol + [(b"Sec-Websocket-Protocol", b"one, two")], subprotocol=accept_subprotocol, ) headers = normed_header_dict(response.headers) @@ -249,7 +252,7 @@ def test_handshake_with_extension_params() -> None: ( b"Sec-Websocket-Extensions", (f"{extension.name}; {offered_params}").encode("ascii"), - ) + ), ], extensions=[extension], ) @@ -268,7 +271,7 @@ def test_handshake_with_extra_unaccepted_extension() -> None: ( b"Sec-Websocket-Extensions", b"pretend, %s" % extension.name.encode("ascii"), - ) + ), ], extensions=[extension], ) @@ -285,7 +288,7 @@ def test_protocol_error() -> None: def _make_handshake_rejection( - status_code: int, body: Optional[bytes] = None + status_code: int, body: Optional[bytes] = None, ) -> List[h11.Event]: client = h11.Connection(h11.CLIENT) server = WSConnection(SERVER) @@ -302,8 +305,8 @@ def _make_handshake_rejection( (b"Sec-WebSocket-Version", b"13"), (b"Sec-WebSocket-Key", nonce), ], - ) - ) + ), + ), ) if body is not None: client.receive_data( @@ -312,8 +315,8 @@ def _make_handshake_rejection( headers=[(b"content-length", b"%d" % len(body))], status_code=status_code, has_body=True, - ) - ) + ), + ), ) client.receive_data(server.send(RejectData(data=body))) else: @@ -321,7 +324,7 @@ def _make_handshake_rejection( events = [] while True: event = client.next_event() - events.append(cast(h11.Event, event)) + events.append(cast("h11.Event", event)) if isinstance(event, h11.EndOfMessage): return events diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 26e88682..00000000 --- a/tox.ini +++ /dev/null @@ -1,69 +0,0 @@ -[tox] -envlist = py37, py38, py39, py310, py311, pypy3, lint, docs, packaging - -[gh-actions] -python = - 3.7: py37 - 3.8: py38 - 3.9: py39 - 3.10: py310 - 3.11: py311, lint, docs, packaging - pypy3: pypy3 - -[testenv] -passenv = - GITHUB_* -deps = - pytest - pytest-cov - pytest-xdist -commands = - pytest --cov-report=xml --cov-report=term --cov=wsproto {posargs} - -[testenv:pypy3] -# temporarily disable coverage testing on PyPy due to performance problems -commands = pytest {posargs} - -[testenv:lint] -deps = - flake8 - black - isort - mypy - {[testenv]deps} -commands = - flake8 src/ test/ - black --check --diff src/ test/ example/ bench/ - isort --check --diff src/ test/ example/ bench/ - mypy src/ test/ example/ bench/ - -[testenv:docs] -deps = - -r docs/source/requirements.txt -allowlist_externals = make -changedir = {toxinidir}/docs -commands = - make clean - make html - -[testenv:packaging] -basepython = python3.10 -deps = - check-manifest - readme-renderer - twine -allowlist_externals = rm -commands = - rm -rf dist/ - check-manifest - python setup.py sdist bdist_wheel - twine check dist/* - -[testenv:publish] -basepython = {[testenv:packaging]basepython} -deps = - {[testenv:packaging]deps} -allowlist_externals = {[testenv:packaging]allowlist_externals} -commands = - {[testenv:packaging]commands} - twine upload dist/*