From c2f32c49e8e4976bf163e5b6793a79bb3c41a67c Mon Sep 17 00:00:00 2001 From: ValueRaider Date: Sat, 29 Jun 2024 21:10:42 +0100 Subject: [PATCH] Fix pickling --- requests_ratelimiter/requests_ratelimiter.py | 9 +++++++ test/test_requests_ratelimiter.py | 27 ++++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/requests_ratelimiter/requests_ratelimiter.py b/requests_ratelimiter/requests_ratelimiter.py index e8345a3..6dfbacc 100644 --- a/requests_ratelimiter/requests_ratelimiter.py +++ b/requests_ratelimiter/requests_ratelimiter.py @@ -168,6 +168,15 @@ class LimiterSession(LimiterMixin, Session): per_host: Track request rate limits separately for each host limit_statuses: Alternative HTTP status codes that indicate a rate limit was exceeded """ + + __attrs__ = Session.__attrs__ + [ + 'limiter', + 'limit_statuses', + 'max_delay', + 'per_host', + 'bucket_name', + '_default_bucket' + ] class LimiterAdapter(LimiterMixin, HTTPAdapter): # type: ignore # send signature accepts **kwargs diff --git a/test/test_requests_ratelimiter.py b/test/test_requests_ratelimiter.py index b7ab45b..b0050e7 100644 --- a/test/test_requests_ratelimiter.py +++ b/test/test_requests_ratelimiter.py @@ -14,6 +14,8 @@ from time import sleep from unittest.mock import patch +import pickle + import pytest from pyrate_limiter import Duration, Limiter, RequestRate, SQLiteBucket from requests import Response, Session @@ -244,3 +246,28 @@ class CachedLimiterSession(CacheMixin, LimiterMixin, Session): for _ in range(10): session.get(MOCKED_URL) assert mock_sleep.called is False + + +def test_inherited_session_attributes(): + # Test that inherited Session attributes are preserved + session = LimiterSession(per_second=5) + assert hasattr(session, 'headers') + assert hasattr(session, 'cookies') + assert hasattr(session, 'auth') + assert hasattr(session, 'hooks') + + +def test_pickling_and_unpickling(): + # Test pickling and unpickling of LimiterSession instance + session = LimiterSession(per_second=5) + pickled_session = pickle.dumps(session) + assert pickled_session is not None + unpickled_session = pickle.loads(pickled_session) + assert unpickled_session is not None + + # Check that the unpickled instance has the same attributes + assert unpickled_session.per_host == session.per_host + assert unpickled_session.max_delay == session.max_delay + assert unpickled_session.bucket_name == session.bucket_name + assert unpickled_session.limit_statuses == session.limit_statuses + assert unpickled_session._default_bucket == session._default_bucket