diff --git a/docs/getting_started/configuration.rst b/docs/getting_started/configuration.rst index 4159530d5..44e648630 100644 --- a/docs/getting_started/configuration.rst +++ b/docs/getting_started/configuration.rst @@ -23,10 +23,10 @@ Environment variables have the highest priority, followed by keyword arguments t Using an HTTP or HTTPS proxy with PRAW -------------------------------------- -PRAW internally relies upon the requests_ package to handle HTTP requests. Requests +PRAW internally relies upon the niquests_ package to handle HTTP requests. Niquests supports use of ``HTTP_PROXY`` and ``HTTPS_PROXY`` environment variables in order to proxy HTTP and HTTPS requests respectively [`ref -`_]. +`_]. Given that PRAW exclusively communicates with Reddit via HTTPS, only the ``HTTPS_PROXY`` option should be required. @@ -41,19 +41,19 @@ variable can be provided on the command line like so: Configuring a custom requests Session ------------------------------------- -PRAW uses requests_ to handle networking. If your use-case requires custom +PRAW uses niquests_ to handle networking. If your use-case requires custom configuration, it is possible to configure a custom Session_ instance and then use it with PRAW. For example, some networks use self-signed SSL certificates when connecting to HTTPS -sites. By default, this would raise an exception in requests_. To use a self-signed SSL -certificate without an exception from requests_, first export the certificate as a +sites. By default, this would raise an exception in niquests_. To use a self-signed SSL +certificate without an exception from niquests_, first export the certificate as a ``.pem`` file. Then configure PRAW like so: .. code-block:: python import praw - from requests import Session + from niquests import Session session = Session() @@ -69,11 +69,11 @@ certificate without an exception from requests_, first export the certificate as The code above creates a custom Session_ instance and `configures it to use a custom certificate -`_, then -passes it as a parameter when creating the :class:`.Reddit` instance. Note that the +`_, +then passes it as a parameter when creating the :class:`.Reddit` instance. Note that the example above uses a :ref:`password_flow` authentication type, but this method will work for any authentication type. -.. _requests: https://requests.readthedocs.io +.. _niquests: https://niquests.readthedocs.io -.. _session: https://2.python-requests.org/en/master/api/#requests.Session +.. _session: https://niquests.readthedocs.io/en/latest/user/advanced.html diff --git a/docs/getting_started/multiple_instances.rst b/docs/getting_started/multiple_instances.rst index f9dd2683e..237f042c5 100644 --- a/docs/getting_started/multiple_instances.rst +++ b/docs/getting_started/multiple_instances.rst @@ -52,8 +52,8 @@ Multiple Threads PRAW is not thread safe. In a nutshell, instances of :class:`.Reddit` are not thread-safe for a number of reasons -in its own code and each instance depends on an instance of ``requests.Session``, which -is not thread-safe [`ref `_]. +in its own code and each instance depends on an instance of ``niquests.Session``, which +is thread-safe. In theory, having a unique :class:`.Reddit` instance for each thread, and making sure that the instances are used in their respective threads only, will work. diff --git a/docs/tutorials/refresh_token.rst b/docs/tutorials/refresh_token.rst index 2a78c41d9..1d8ab7b35 100644 --- a/docs/tutorials/refresh_token.rst +++ b/docs/tutorials/refresh_token.rst @@ -14,9 +14,9 @@ following: .. code-block:: python - import requests + import niquests - response = requests.get( + response = niquests.get( "https://www.reddit.com/api/v1/scopes.json", headers={"User-Agent": "fetch-scopes by u/bboe"}, ) diff --git a/praw/models/reddit/subreddit.py b/praw/models/reddit/subreddit.py index 643e809d6..326b3e783 100644 --- a/praw/models/reddit/subreddit.py +++ b/praw/models/reddit/subreddit.py @@ -13,10 +13,9 @@ from warnings import warn from xml.etree.ElementTree import XML -import websocket +from niquests.exceptions import HTTPError, ReadTimeout, Timeout from prawcore import Redirect from prawcore.exceptions import ServerError -from requests.exceptions import HTTPError from ...const import API_PATH, JPEG_HEADER from ...exceptions import ( @@ -3067,30 +3066,40 @@ def _submit_media( """ response = self._reddit.post(API_PATH["submit"], data=data) websocket_url = response["json"]["data"]["websocket_url"] - connection = None + ws_response = None if websocket_url is not None and not without_websockets: try: - connection = websocket.create_connection(websocket_url, timeout=timeout) + ws_response = self._reddit._core._requestor._http.get( + websocket_url, + timeout=timeout, + ).raise_for_status() except ( - OSError, - websocket.WebSocketException, - BlockingIOError, + HTTPError, + Timeout, ) as ws_exception: msg = "Error establishing websocket connection." raise WebSocketException(msg, ws_exception) from None - if connection is None: + if ws_response is None: return None try: - ws_update = loads(connection.recv()) - connection.close() - except (OSError, websocket.WebSocketException, BlockingIOError) as ws_exception: + ws_update = loads(ws_response.extension.next_payload()) + except ( + ReadTimeout, + HTTPError, + ) as ws_exception: msg = "Websocket error. Check your media file. Your post may still have been created." raise WebSocketException( msg, ws_exception, ) from None + finally: + if ( + ws_response.extension is not None + and ws_response.extension.closed is False + ): + ws_response.extension.close() if ws_update.get("type") == "failed": raise MediaPostFailed url = ws_update["payload"]["redirect"] diff --git a/pyproject.toml b/pyproject.toml index 2d2436d87..760171e97 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,9 +21,9 @@ classifiers = [ "Topic :: Utilities" ] dependencies = [ - "prawcore >=2.4, <3", + "prawcore@git+https://github.com/Ousret/prawcore@feat-niquests", "update_checker >=0.18", - "websocket-client >=0.54.0" + "niquests[ws]>=3.10,<4" ] dynamic = ["version", "description"] keywords = ["reddit", "api", "wrapper"] @@ -56,9 +56,7 @@ readthedocs = [ test = [ "betamax >=0.8, <0.9", "betamax-matchers >=0.3.0, <0.5", - "pytest >=2.7.3", - "requests >=2.20.1, <3", - "urllib3 ==1.26.*, <2" + "pytest >=2.7.3" ] [project.urls] @@ -78,6 +76,11 @@ extend_exclude = ['./docs/examples/'] profile = 'black' skip_glob = '.venv*' +[tool.pytest.ini_options] +# this avoids pytest loading betamax+Requests at boot. +# this allows us to patch betamax and makes it use Niquests instead. +addopts = "-p no:pytest-betamax" + [tool.ruff] target-version = "py38" include = [ diff --git a/tests/conftest.py b/tests/conftest.py index 7e003b831..b58e7391c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,51 @@ import socket import time from base64 import b64encode -from sys import platform +from sys import platform, modules +from urllib.parse import urlparse + +import requests +import niquests +import urllib3 + +# betamax is tied to Requests +# and Niquests is almost entirely compatible with it. +# we can fool it without effort. +modules["requests"] = niquests +modules["requests.adapters"] = niquests.adapters +modules["requests.models"] = niquests.models +modules["requests.exceptions"] = niquests.exceptions +modules["requests.packages.urllib3"] = urllib3 + +# niquests no longer have a compat submodule +# but betamax need it. no worries, as betamax +# explicitly need requests, we'll give it to him. +modules["requests.compat"] = requests.compat + +# doing the import now will make betamax working with Niquests! +# no extra effort. +import betamax + +# the base mock does not implement close(), which is required +# for our HTTP client. No biggy. +betamax.mock_response.MockHTTPResponse.close = lambda _: None + + +# betamax have a tiny bug in URI matcher +# https://example.com != https://example.com/ +# And Niquests does not enforce the trailing '/' +# when preparing a Request. +def _patched_parse(self, uri): + parsed = urlparse(uri) + return { + "scheme": parsed.scheme, + "netloc": parsed.netloc, + "path": parsed.path or "/", + "fragment": parsed.fragment, + } + + +betamax.matchers.uri.URIMatcher.parse = _patched_parse import pytest @@ -31,6 +75,42 @@ def _get_path(name): return _get_path +@pytest.fixture(autouse=True) +def lax_content_length_strict(monkeypatch): + import io + import base64 + from betamax.util import body_io + from urllib3 import HTTPResponse + from betamax.mock_response import MockHTTPResponse + + # our cassettes are[...] pretty much broken. + # Some declared Content-Length don't match the bodies. + # Let's disable enforced content-length here. + def _patched_add_urllib3_response(serialized, response, headers): + if "base64_string" in serialized["body"]: + body = io.BytesIO( + base64.b64decode(serialized["body"]["base64_string"].encode()) + ) + else: + body = body_io(**serialized["body"]) + + h = HTTPResponse( + body, + status=response.status_code, + reason=response.reason, + headers=headers, + preload_content=False, + original_response=MockHTTPResponse(headers), + enforce_content_length=False, + ) + + response.raw = h + + monkeypatch.setattr( + betamax.util, "add_urllib3_response", _patched_add_urllib3_response + ) + + def pytest_configure(config): pytest.placeholders = Placeholders(placeholders) config.addinivalue_line( diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py index f0f68eb2d..83e76c6f8 100644 --- a/tests/integration/__init__.py +++ b/tests/integration/__init__.py @@ -4,8 +4,8 @@ from urllib.parse import quote_plus import betamax +import niquests import pytest -import requests from betamax.cassette import Cassette from praw import Reddit @@ -70,7 +70,7 @@ def read_only(self, reddit): @pytest.fixture(autouse=True) def recorder(self): """Configure Betamax.""" - session = requests.Session() + session = niquests.Session() recorder = betamax.Betamax(session) recorder.register_serializer(PrettyJSONSerializer) with betamax.Betamax.configure() as config: diff --git a/tests/integration/models/reddit/test_subreddit.py b/tests/integration/models/reddit/test_subreddit.py index f916d0409..323caa46f 100644 --- a/tests/integration/models/reddit/test_subreddit.py +++ b/tests/integration/models/reddit/test_subreddit.py @@ -1,13 +1,11 @@ """Test praw.models.subreddit.""" -import socket from json import dumps from unittest import mock from unittest.mock import MagicMock +import niquests import pytest -import requests -import websocket from prawcore import BadRequest, Forbidden, NotFound, RequestException, TooLarge from praw.const import PNG_HEADER @@ -1223,6 +1221,10 @@ class WebsocketMock: def make_dict(cls, post_id): return {"payload": {"redirect": cls.POST_URL.format(post_id)}} + @property + def closed(self) -> bool: + return False + def __call__(self, *args, **kwargs): return self @@ -1233,14 +1235,27 @@ def __init__(self, *post_ids): def close(self, *args, **kwargs): pass - def recv(self): + def next_payload(self): if not self.post_ids: - raise websocket.WebSocketTimeoutException() + raise niquests.ReadTimeout assert 0 <= self.i + 1 < len(self.post_ids) self.i += 1 return dumps(self.make_dict(self.post_ids[self.i])) +class ResponseWithWebSocketExtMock: + + def __init__(self, fake_extension: WebsocketMock): + self.extension = fake_extension + + @property + def status_code(self) -> int: + return 101 + + def raise_for_status(self): + return self + + class WebsocketMockException: def __init__(self, close_exc=None, recv_exc=None): """Initialize a WebsocketMockException. @@ -1548,9 +1563,11 @@ def test_submit_gallery__flair(self, image_path, reddit): assert submission.link_flair_text == flair_text @mock.patch( - "websocket.create_connection", + "niquests.Session.get", new=MagicMock( - return_value=WebsocketMock("16xb01r", "16xb06z", "16xb0aa") + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("16xb01r", "16xb06z", "16xb0aa") + ) ), # update with cassette ) def test_submit_image(self, image_path, reddit): @@ -1565,7 +1582,8 @@ def test_submit_image(self, image_path, reddit): @pytest.mark.cassette_name("TestSubreddit.test_submit_image") @mock.patch( - "websocket.create_connection", new=MagicMock(return_value=WebsocketMock()) + "niquests.Session.get", + new=MagicMock(return_value=ResponseWithWebSocketExtMock(WebsocketMock())), ) def test_submit_image__bad_websocket(self, image_path, reddit): reddit.read_only = False @@ -1576,8 +1594,10 @@ def test_submit_image__bad_websocket(self, image_path, reddit): subreddit.submit_image("Test Title", image) @mock.patch( - "websocket.create_connection", - new=MagicMock(return_value=WebsocketMock("ah3gqo")), + "niquests.Session.get", + new=MagicMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("ah3gqo")) + ), ) # update with cassette def test_submit_image__flair(self, image_path, reddit): flair_id = "6bd28436-1aa7-11e9-9902-0e05ab0fad46" @@ -1611,7 +1631,7 @@ def test_submit_image__large(self, image_path, reddit, tmp_path): def patch_request(url, *args, **kwargs): """Patch requests to return mock data on specific url.""" if "https://reddit-uploaded-media.s3-accelerate.amazonaws.com" in url: - response = requests.Response() + response = niquests.Response() response._content = mock_data.encode("utf-8") response.encoding = "utf-8" response.status_code = 400 @@ -1628,7 +1648,7 @@ def patch_request(url, *args, **kwargs): reddit._core._requestor._http.post = _post @mock.patch( - "websocket.create_connection", new=MagicMock(side_effect=BlockingIOError) + "niquests.Session.get", new=MagicMock(side_effect=niquests.Timeout) ) # happens with timeout=0 @pytest.mark.cassette_name("TestSubreddit.test_submit_image") def test_submit_image__timeout_1(self, image_path, reddit): @@ -1638,69 +1658,6 @@ def test_submit_image__timeout_1(self, image_path, reddit): with pytest.raises(WebSocketException): subreddit.submit_image("Test Title", image) - @mock.patch( - "websocket.create_connection", - new=MagicMock( - side_effect=socket.timeout - # happens with timeout=0.00001 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_image") - def test_submit_image__timeout_2(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - image = image_path("test.jpg") - with pytest.raises(WebSocketException): - subreddit.submit_image("Test Title", image) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - recv_exc=websocket.WebSocketTimeoutException() - ), # happens with timeout=0.1 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_image") - def test_submit_image__timeout_3(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - image = image_path("test.jpg") - with pytest.raises(WebSocketException): - subreddit.submit_image("Test Title", image) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - close_exc=websocket.WebSocketTimeoutException() - ), # could happen, and PRAW should handle it - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_image") - def test_submit_image__timeout_4(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - image = image_path("test.jpg") - with pytest.raises(WebSocketException): - subreddit.submit_image("Test Title", image) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - recv_exc=websocket.WebSocketConnectionClosedException() - ), # from issue #1124 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_image") - def test_submit_image__timeout_5(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - image = image_path("test.jpg") - with pytest.raises(WebSocketException): - subreddit.submit_image("Test Title", image) - def test_submit_image__without_websockets(self, image_path, reddit): reddit.read_only = False subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) @@ -1712,8 +1669,10 @@ def test_submit_image__without_websockets(self, image_path, reddit): assert submission is None @mock.patch( - "websocket.create_connection", - new=MagicMock(return_value=WebsocketMock("k5s3b3")), + "niquests.Session.get", + new=MagicMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("k5s3b3")) + ), ) # update with cassette def test_submit_image_chat(self, image_path, reddit): reddit.read_only = False @@ -1785,9 +1744,11 @@ def test_submit_poll__live_chat(self, reddit): assert submission.discussion_type == "CHAT" @mock.patch( - "websocket.create_connection", + "niquests.Session.get", new=MagicMock( - return_value=WebsocketMock("k5rsq3", "k5rt9d"), # update with cassette + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("k5rsq3", "k5rt9d") + ), # update with cassette ), ) def test_submit_video(self, image_path, reddit): @@ -1802,7 +1763,8 @@ def test_submit_video(self, image_path, reddit): @pytest.mark.cassette_name("TestSubreddit.test_submit_video") @mock.patch( - "websocket.create_connection", new=MagicMock(return_value=WebsocketMock()) + "niquests.Session.get", + new=MagicMock(return_value=ResponseWithWebSocketExtMock(WebsocketMock())), ) def test_submit_video__bad_websocket(self, image_path, reddit): reddit.read_only = False @@ -1813,8 +1775,10 @@ def test_submit_video__bad_websocket(self, image_path, reddit): subreddit.submit_video("Test Title", video) @mock.patch( - "websocket.create_connection", - new=MagicMock(return_value=WebsocketMock("ahells")), + "niquests.Session.get", + new=MagicMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("ahells")) + ), ) # update with cassette def test_submit_video__flair(self, image_path, reddit): flair_id = "6bd28436-1aa7-11e9-9902-0e05ab0fad46" @@ -1830,9 +1794,9 @@ def test_submit_video__flair(self, image_path, reddit): assert submission.link_flair_text == flair_text @mock.patch( - "websocket.create_connection", + "niquests.Session.get", new=MagicMock( - return_value=WebsocketMock("k5rvt5", "k5rwbo") + return_value=ResponseWithWebSocketExtMock(WebsocketMock("k5rvt5", "k5rwbo")) ), # update with cassette ) def test_submit_video__thumbnail(self, image_path, reddit): @@ -1852,7 +1816,7 @@ def test_submit_video__thumbnail(self, image_path, reddit): assert submission.title == "Test Title" @mock.patch( - "websocket.create_connection", new=MagicMock(side_effect=BlockingIOError) + "niquests.Session.get", new=MagicMock(side_effect=niquests.Timeout) ) # happens with timeout=0 @pytest.mark.cassette_name("TestSubreddit.test_submit_video") def test_submit_video__timeout_1(self, image_path, reddit): @@ -1863,72 +1827,11 @@ def test_submit_video__timeout_1(self, image_path, reddit): subreddit.submit_video("Test Title", video) @mock.patch( - "websocket.create_connection", - new=MagicMock( - side_effect=socket.timeout - # happens with timeout=0.00001 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_video") - def test_submit_video__timeout_2(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - video = image_path("test.mov") - with pytest.raises(WebSocketException): - subreddit.submit_video("Test Title", video) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - recv_exc=websocket.WebSocketTimeoutException() - ), # happens with timeout=0.1 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_video") - def test_submit_video__timeout_3(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - video = image_path("test.mov") - with pytest.raises(WebSocketException): - subreddit.submit_video("Test Title", video) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMockException( - close_exc=websocket.WebSocketTimeoutException() - ), # could happen, and PRAW should handle it - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_video") - def test_submit_video__timeout_4(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - video = image_path("test.mov") - with pytest.raises(WebSocketException): - subreddit.submit_video("Test Title", video) - - @mock.patch( - "websocket.create_connection", + "niquests.Session.get", new=MagicMock( - return_value=WebsocketMockException( - close_exc=websocket.WebSocketConnectionClosedException() - ), # from issue #1124 - ), - ) - @pytest.mark.cassette_name("TestSubreddit.test_submit_video") - def test_submit_video__timeout_5(self, image_path, reddit): - reddit.read_only = False - subreddit = reddit.subreddit(pytest.placeholders.test_subreddit) - video = image_path("test.mov") - with pytest.raises(WebSocketException): - subreddit.submit_video("Test Title", video) - - @mock.patch( - "websocket.create_connection", - new=MagicMock( - return_value=WebsocketMock("k5s10u", "k5s11v"), # update with cassette + return_value=ResponseWithWebSocketExtMock( + WebsocketMock("k5s10u", "k5s11v") + ), # update with cassette ), ) def test_submit_video__videogif(self, image_path, reddit): @@ -1952,8 +1855,10 @@ def test_submit_video__without_websockets(self, image_path, reddit): assert submission is None @mock.patch( - "websocket.create_connection", - new=MagicMock(return_value=WebsocketMock("flnyhf")), + "niquests.Session.get", + new=MagicMock( + return_value=ResponseWithWebSocketExtMock(WebsocketMock("flnyhf")) + ), ) # update with cassette def test_submit_video_chat(self, image_path, reddit): reddit.read_only = False diff --git a/tests/unit/models/reddit/test_subreddit.py b/tests/unit/models/reddit/test_subreddit.py index 2259ab651..46e392e7c 100644 --- a/tests/unit/models/reddit/test_subreddit.py +++ b/tests/unit/models/reddit/test_subreddit.py @@ -1,6 +1,7 @@ import json import pickle from unittest import mock +from unittest.mock import MagicMock import pytest @@ -53,7 +54,7 @@ def test_hash(self, reddit): assert hash(subreddit2) != hash(subreddit3) assert hash(subreddit1) != hash(subreddit3) - @mock.patch("websocket.create_connection") + @mock.patch("niquests.Session.get") @mock.patch( "praw.models.Subreddit._upload_media", return_value=("fake_media_url", "fake_websocket_url"), @@ -64,14 +65,23 @@ def test_hash(self, reddit): def test_invalid_media( self, _mock_post, _mock_upload_media, connection_mock, reddit ): - connection_mock().recv.return_value = json.dumps( - {"payload": {}, "type": "failed"} + connection_mock.return_value = MagicMock( + status_code=101, + raise_for_status=MagicMock( + return_value=MagicMock( + extension=MagicMock( + next_payload=MagicMock( + return_value=json.dumps({"payload": {}, "type": "failed"}) + ) + ), + ) + ), ) with pytest.raises(MediaPostFailed): reddit.subreddit("test").submit_image("Test", "dummy path") @mock.patch("praw.models.Subreddit._read_and_post_media") - @mock.patch("websocket.create_connection") + @mock.patch("niquests.Session.get") @mock.patch( "praw.Reddit.post", return_value={ diff --git a/tests/unit/test_reddit.py b/tests/unit/test_reddit.py index 103edc4db..9f2cf9cf3 100644 --- a/tests/unit/test_reddit.py +++ b/tests/unit/test_reddit.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock import pytest -import requests +import niquests from prawcore import Requestor from prawcore.exceptions import BadRequest @@ -37,7 +37,7 @@ async def check_async(reddit): @staticmethod def patch_request(*args, **kwargs): """Patch requests to return mock data on specific url.""" - response = requests.Response() + response = niquests.Response() response._content = '{"name":"username"}'.encode("utf-8") response.status_code = 200 return response