Skip to content

Commit

Permalink
Merge pull request #458 from consideRatio/test-subprotocols
Browse files Browse the repository at this point in the history
Ensure no blank `Sec-Websocket-Protocol` headers and warn if websocket subprotocol edge case occur
  • Loading branch information
consideRatio authored Feb 23, 2024
2 parents 288c74b + eda6136 commit 1b9f84b
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 46 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ jobs:
pip-install-constraints: >-
jupyter-server==1.0
simpervisor==1.0
tornado==5.0
tornado==5.1
traitlets==4.2.1
steps:
Expand Down
39 changes: 32 additions & 7 deletions jupyter_server_proxy/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ def __init__(self, *args, **kwargs):
"rewrite_response",
tuple(),
)
self.subprotocols = None
super().__init__(*args, **kwargs)

# Support/use jupyter_server config arguments allow_origin and allow_origin_pat
Expand Down Expand Up @@ -489,15 +488,28 @@ async def start_websocket_connection():
self.log.info(f"Trying to establish websocket connection to {client_uri}")
self._record_activity()
request = httpclient.HTTPRequest(url=client_uri, headers=headers)
subprotocols = (
[self.selected_subprotocol] if self.selected_subprotocol else None
)
self.ws = await pingable_ws_connect(
request=request,
on_message_callback=message_cb,
on_ping_callback=ping_cb,
subprotocols=self.subprotocols,
subprotocols=subprotocols,
resolver=resolver,
)
self._record_activity()
self.log.info(f"Websocket connection established to {client_uri}")
if (
subprotocols
and self.ws.selected_subprotocol != self.selected_subprotocol
):
self.log.warn(
f"Websocket subprotocol between proxy/server ({self.ws.selected_subprotocol}) "
f"became different than for client/proxy ({self.selected_subprotocol}) "
"due to https://github.com/jupyterhub/jupyter-server-proxy/issues/459. "
f"Requested subprotocols were {subprotocols}."
)

# Wait for the WebSocket to be connected before resolving.
# Otherwise, messages sent by the client before the
Expand Down Expand Up @@ -531,12 +543,25 @@ def check_xsrf_cookie(self):
"""

def select_subprotocol(self, subprotocols):
"""Select a single Sec-WebSocket-Protocol during handshake."""
self.subprotocols = subprotocols
if isinstance(subprotocols, list) and subprotocols:
self.log.debug(f"Client sent subprotocols: {subprotocols}")
"""
Select a single Sec-WebSocket-Protocol during handshake.
Note that this subprotocol selection should really be delegated to the
server we proxy to, but we don't! For this to happen, we would need to
delay accepting the handshake with the client until we have successfully
handshaked with the server. This issue is tracked via
https://github.com/jupyterhub/jupyter-server-proxy/issues/459.
Overrides `tornado.websocket.WebSocketHandler.select_subprotocol` that
includes an informative docstring:
https://github.com/tornadoweb/tornado/blob/v6.4.0/tornado/websocket.py#L337-L360.
"""
if subprotocols:
self.log.debug(
f"Client sent subprotocols: {subprotocols}, selecting the first"
)
return subprotocols[0]
return super().select_subprotocol(subprotocols)
return None


class LocalProxyHandler(ProxyHandler):
Expand Down
23 changes: 18 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,14 @@ dependencies = [
"importlib_metadata >=4.8.3 ; python_version<\"3.10\"",
"jupyter-server >=1.0",
"simpervisor >=1.0",
"tornado >=5.0",
"tornado >=5.1",
"traitlets >= 4.2.1",
]

[project.optional-dependencies]
test = [
"pytest",
"pytest-asyncio",
"pytest-cov",
"pytest-html",
]
Expand Down Expand Up @@ -195,21 +196,33 @@ src = "pyproject.toml"
[[tool.tbump.file]]
src = "labextension/package.json"


# pytest is used for running Python based tests
#
# ref: https://docs.pytest.org/en/stable/
#
[tool.pytest.ini_options]
cache_dir = "build/.cache/pytest"
testpaths = ["tests"]
addopts = [
"-vv",
"--verbose",
"--durations=10",
"--color=yes",
"--cov=jupyter_server_proxy",
"--cov-branch",
"--cov-context=test",
"--cov-report=term-missing:skip-covered",
"--cov-report=html:build/coverage",
"--no-cov-on-fail",
"--html=build/pytest/index.html",
"--color=yes",
]
asyncio_mode = "auto"
testpaths = ["tests"]
cache_dir = "build/.cache/pytest"


# pytest-cov / coverage is used to measure code coverage of tests
#
# ref: https://coverage.readthedocs.io/en/stable/config.html
#
[tool.coverage.run]
data_file = "build/.coverage"
concurrency = [
Expand Down
30 changes: 26 additions & 4 deletions tests/resources/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,26 +54,48 @@ def get(self):


class EchoWebSocket(tornado.websocket.WebSocketHandler):
"""Echoes back received messages."""

def on_message(self, message):
self.write_message(message)


class HeadersWebSocket(tornado.websocket.WebSocketHandler):
"""Echoes back incoming request headers."""

def on_message(self, message):
self.write_message(json.dumps(dict(self.request.headers)))


class SubprotocolWebSocket(tornado.websocket.WebSocketHandler):
"""
Echoes back requested subprotocols and selected subprotocol as a JSON
encoded message, and selects subprotocols in a very particular way to help
us test things.
"""

def __init__(self, *args, **kwargs):
self._subprotocols = None
self._requested_subprotocols = None
super().__init__(*args, **kwargs)

def select_subprotocol(self, subprotocols):
self._subprotocols = subprotocols
return None
self._requested_subprotocols = subprotocols if subprotocols else None

if not subprotocols:
return None
if "please_select_no_protocol" in subprotocols:
return None
if "favored" in subprotocols:
return "favored"
else:
return subprotocols[0]

def on_message(self, message):
self.write_message(json.dumps(self._subprotocols))
response = {
"requested_subprotocols": self._requested_subprotocols,
"selected_subprotocol": self.selected_subprotocol,
}
self.write_message(json.dumps(response))


def main():
Expand Down
91 changes: 62 additions & 29 deletions tests/test_proxies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
import gzip
import json
import sys
Expand Down Expand Up @@ -332,14 +331,9 @@ def test_server_content_encoding_header(
assert f.read() == b"this is a test"


@pytest.fixture(scope="module")
def event_loop():
loop = asyncio.get_event_loop()
yield loop
loop.close()


async def _websocket_echo(a_server_port_and_token: Tuple[int, str]) -> None:
async def test_server_proxy_websocket_messages(
a_server_port_and_token: Tuple[int, str]
) -> None:
PORT = a_server_port_and_token[0]
url = f"ws://{LOCALHOST}:{PORT}/python-websocket/echosocket"
conn = await websocket_connect(url)
Expand All @@ -349,13 +343,7 @@ async def _websocket_echo(a_server_port_and_token: Tuple[int, str]) -> None:
assert msg == expected_msg


def test_server_proxy_websocket(
event_loop, a_server_port_and_token: Tuple[int, str]
) -> None:
event_loop.run_until_complete(_websocket_echo(a_server_port_and_token))


async def _websocket_headers(a_server_port_and_token: Tuple[int, str]) -> None:
async def test_server_proxy_websocket_headers(a_server_port_and_token: Tuple[int, str]):
PORT = a_server_port_and_token[0]
url = f"ws://{LOCALHOST}:{PORT}/python-websocket/headerssocket"
conn = await websocket_connect(url)
Expand All @@ -366,25 +354,68 @@ async def _websocket_headers(a_server_port_and_token: Tuple[int, str]) -> None:
assert headers["X-Custom-Header"] == "pytest-23456"


def test_server_proxy_websocket_headers(
event_loop, a_server_port_and_token: Tuple[int, str]
@pytest.mark.parametrize(
"client_requested,server_received,server_responded,proxy_responded",
[
(None, None, None, None),
(["first"], ["first"], "first", "first"),
# IMPORTANT: The tests below verify current bugged behavior, and the
# commented out tests is what we want to succeed!
#
# The proxy websocket should actually respond the handshake
# with a subprotocol based on a the server handshake
# response, but we are finalizing the client/proxy handshake
# before the proxy/server handshake, and that makes it
# impossible. We currently instead just pick the first
# requested protocol no matter what what subprotocol the
# server picks.
#
# Bug 1 - server wasn't passed all subprotocols:
(["first", "second"], ["first"], "first", "first"),
# (["first", "second"], ["first", "second"], "first", "first"),
#
# Bug 2 - server_responded doesn't match proxy_responded:
(["first", "favored"], ["first"], "first", "first"),
# (["first", "favored"], ["first", "favored"], "favored", "favored"),
(
["please_select_no_protocol"],
["please_select_no_protocol"],
None,
"please_select_no_protocol",
),
# (["please_select_no_protocol"], ["please_select_no_protocol"], None, None),
],
)
async def test_server_proxy_websocket_subprotocols(
a_server_port_and_token: Tuple[int, str],
client_requested,
server_received,
server_responded,
proxy_responded,
):
event_loop.run_until_complete(_websocket_headers(a_server_port_and_token))


async def _websocket_subprotocols(a_server_port_and_token: Tuple[int, str]) -> None:
PORT, TOKEN = a_server_port_and_token
url = f"ws://{LOCALHOST}:{PORT}/python-websocket/subprotocolsocket"
conn = await websocket_connect(url, subprotocols=["protocol_1", "protocol_2"])
conn = await websocket_connect(url, subprotocols=client_requested)
await conn.write_message("Hello, world!")

# verify understanding of websocket_connect that this test relies on
if client_requested:
assert "Sec-Websocket-Protocol" in conn.request.headers
else:
assert "Sec-Websocket-Protocol" not in conn.request.headers

msg = await conn.read_message()
assert json.loads(msg) == ["protocol_1", "protocol_2"]
info = json.loads(msg)

assert info["requested_subprotocols"] == server_received
assert info["selected_subprotocol"] == server_responded
assert conn.selected_subprotocol == proxy_responded

def test_server_proxy_websocket_subprotocols(
event_loop, a_server_port_and_token: Tuple[int, str]
):
event_loop.run_until_complete(_websocket_subprotocols(a_server_port_and_token))
# verify proxy response headers directly
if proxy_responded is None:
assert "Sec-Websocket-Protocol" not in conn.headers
else:
assert "Sec-Websocket-Protocol" in conn.headers


@pytest.mark.parametrize(
Expand All @@ -410,7 +441,9 @@ def test_bad_server_proxy_url(
assert "X-ProxyContextPath" not in r.headers


def test_callable_environment_formatting(a_server_port_and_token: Tuple[int, str]) -> None:
def test_callable_environment_formatting(
a_server_port_and_token: Tuple[int, str]
) -> None:
PORT, TOKEN = a_server_port_and_token
r = request_get(PORT, "/python-http-callable-env/test", TOKEN)
assert r.code == 200

0 comments on commit 1b9f84b

Please sign in to comment.