diff --git a/cpr/multiperform.cpp b/cpr/multiperform.cpp index 27b246316..27c40ef30 100644 --- a/cpr/multiperform.cpp +++ b/cpr/multiperform.cpp @@ -6,6 +6,7 @@ #include "cpr/response.h" #include "cpr/session.h" #include +#include #include #include #include @@ -13,13 +14,17 @@ #include #include #include +#include #include #include #include namespace cpr { -MultiPerform::MultiPerform() : multicurl_(new CurlMultiHolder()) {} +MultiPerform::MultiPerform() : multicurl_(new CurlMultiHolder()) { + current_interceptor_ = interceptors_.end(); + first_interceptor_ = interceptors_.end(); +} MultiPerform::~MultiPerform() { // Unlock all sessions @@ -154,8 +159,9 @@ std::vector MultiPerform::ReadMultiInfo(const std::function MultiPerform::MakeRequest() { - if (!interceptors_.empty()) { - return intercept(); + const std::optional> r = intercept(); + if (r.has_value()) { + return r.value(); } DoMultiPerform(); @@ -163,8 +169,9 @@ std::vector MultiPerform::MakeRequest() { } std::vector MultiPerform::MakeDownloadRequest() { - if (!interceptors_.empty()) { - return intercept(); + const std::optional> r = intercept(); + if (r.has_value()) { + return r.value(); } DoMultiPerform(); @@ -325,15 +332,33 @@ std::vector MultiPerform::proceed() { return MakeRequest(); } -std::vector MultiPerform::intercept() { - // At least one interceptor exists -> Execute its intercept function - const std::shared_ptr interceptor = interceptors_.front(); - interceptors_.pop(); - return interceptor->intercept(*this); +const std::optional> MultiPerform::intercept() { + if (current_interceptor_ == interceptors_.end()) { + current_interceptor_ = first_interceptor_; + } else { + current_interceptor_++; + } + + if (current_interceptor_ != interceptors_.end()) { + auto icpt = current_interceptor_; + // Nested makeRequest() start at first_interceptor_, thus excluding previous interceptors. + first_interceptor_ = current_interceptor_; + ++first_interceptor_; + + const std::optional> r = (*current_interceptor_)->intercept(*this); + + first_interceptor_ = icpt; + + return r; + } + return std::nullopt; } void MultiPerform::AddInterceptor(const std::shared_ptr& pinterceptor) { - interceptors_.push(pinterceptor); + // Shall only add before first interceptor run + assert(current_interceptor_ == interceptors_.end()); + interceptors_.push_back(pinterceptor); + first_interceptor_ = interceptors_.begin(); } } // namespace cpr diff --git a/cpr/session.cpp b/cpr/session.cpp index 97eb93a71..914d0bff1 100644 --- a/cpr/session.cpp +++ b/cpr/session.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -139,11 +140,14 @@ Session::Session() : curl_(new CurlHolder()) { #if LIBCURL_VERSION_NUM >= 0x071900 // 7.25.0 curl_easy_setopt(curl_->handle, CURLOPT_TCP_KEEPALIVE, 1L); #endif + current_interceptor_ = interceptors_.end(); + first_interceptor_ = interceptors_.end(); } Response Session::makeDownloadRequest() { - if (!interceptors_.empty()) { - return intercept(); + const std::optional r = intercept(); + if (r.has_value()) { + return r.value(); } const CURLcode curl_error = DoEasyPerform(); @@ -262,11 +266,13 @@ void Session::prepareCommonDownload() { } Response Session::makeRequest() { - if (!interceptors_.empty()) { - return intercept(); + const std::optional r = intercept(); + if (r.has_value()) { + return r.value(); } const CURLcode curl_error = DoEasyPerform(); + return Complete(curl_error); } @@ -880,7 +886,10 @@ Response Session::CompleteDownload(CURLcode curl_error) { } void Session::AddInterceptor(const std::shared_ptr& pinterceptor) { - interceptors_.push(pinterceptor); + // Shall only add before first interceptor run + assert(current_interceptor_ == interceptors_.end()); + interceptors_.push_back(pinterceptor); + first_interceptor_ = interceptors_.begin(); } Response Session::proceed() { @@ -888,11 +897,26 @@ Response Session::proceed() { return makeRequest(); } -Response Session::intercept() { - // At least one interceptor exists -> Execute its intercept function - const std::shared_ptr interceptor = interceptors_.front(); - interceptors_.pop(); - return interceptor->intercept(*this); +const std::optional Session::intercept() { + if (current_interceptor_ == interceptors_.end()) { + current_interceptor_ = first_interceptor_; + } else { + current_interceptor_++; + } + + if (current_interceptor_ != interceptors_.end()) { + auto icpt = current_interceptor_; + // Nested makeRequest() start at first_interceptor_, thus excluding previous interceptors. + first_interceptor_ = current_interceptor_; + ++first_interceptor_; + + const std::optional r = (*current_interceptor_)->intercept(*this); + + first_interceptor_ = icpt; + + return r; + } + return std::nullopt; } void Session::prepareBodyPayloadOrMultipart() const { diff --git a/include/cpr/multiperform.h b/include/cpr/multiperform.h index 14b9b1c01..d4ac445ee 100644 --- a/include/cpr/multiperform.h +++ b/include/cpr/multiperform.h @@ -81,7 +81,7 @@ class MultiPerform { template void PrepareDownload(DownloadArgTypes... args); - std::vector intercept(); + const std::optional> intercept(); std::vector proceed(); std::vector MakeRequest(); std::vector MakeDownloadRequest(); @@ -93,7 +93,12 @@ class MultiPerform { std::unique_ptr multicurl_; bool is_download_multi_perform{false}; - std::queue> interceptors_; + using InterceptorsContainer = std::list>; + InterceptorsContainer interceptors_; + // Currently running interceptor + InterceptorsContainer::iterator current_interceptor_; + // Interceptor within the chain where to start with each repeated request + InterceptorsContainer::iterator first_interceptor_; }; template diff --git a/include/cpr/session.h b/include/cpr/session.h index 0ff13f847..df1324707 100644 --- a/include/cpr/session.h +++ b/include/cpr/session.h @@ -5,9 +5,9 @@ #include #include #include +#include #include #include -#include #include #include "cpr/accept_encoding.h" @@ -261,14 +261,20 @@ class Session : public std::enable_shared_from_this { size_t response_string_reserve_size_{0}; std::string response_string_; std::string header_string_; - std::queue> interceptors_; + // Container type is required to keep iterator valid on elem insertion. E.g. list but not vector. + using InterceptorsContainer = std::list>; + InterceptorsContainer interceptors_; + // Currently running interceptor + InterceptorsContainer::const_iterator current_interceptor_; + // Interceptor within the chain where to start with each repeated request + InterceptorsContainer::const_iterator first_interceptor_; bool isUsedInMultiPerform{false}; bool isCancellable{false}; Response makeDownloadRequest(); Response makeRequest(); Response proceed(); - Response intercept(); + const std::optional intercept(); /** * Prepares the curl object for a request with everything used by all requests. **/ diff --git a/test/interceptor_tests.cpp b/test/interceptor_tests.cpp index 8bdc06b25..1732a3944 100644 --- a/test/interceptor_tests.cpp +++ b/test/interceptor_tests.cpp @@ -127,6 +127,19 @@ class ChangeRequestMethodToPatchInterceptor : public Interceptor { } }; +class RetryInterceptor : public Interceptor { + public: + Response intercept(Session& session) override { + // Proceed the chain + Response response = proceed(session); + + // retried request + response = proceed(session); + + return response; + } +}; + TEST(InterceptorTest, HiddenUrlRewriteInterceptorTest) { Url url{server->GetBaseUrl() + "/basic.json"}; Session session; @@ -151,6 +164,32 @@ TEST(InterceptorTest, ChangeStatusCodeInterceptorTest) { EXPECT_EQ(url, response.url); EXPECT_EQ(expected_status_code, response.status_code); EXPECT_EQ(ErrorCode::OK, response.error.code); + + // second request + response = session.Get(); + EXPECT_EQ(url, response.url); + EXPECT_EQ(expected_status_code, response.status_code); + EXPECT_EQ(ErrorCode::OK, response.error.code); +} + +TEST(InterceptorTest, RetryInterceptorTest) { + Url url{server->GetBaseUrl() + "/hello.html"}; + Session session; + session.SetUrl(url); + session.AddInterceptor(std::make_shared()); + session.AddInterceptor(std::make_shared()); + Response response = session.Get(); + + long expected_status_code{12345}; + EXPECT_EQ(url, response.url); + EXPECT_EQ(expected_status_code, response.status_code); + EXPECT_EQ(ErrorCode::OK, response.error.code); + + // second request + response = session.Get(); + EXPECT_EQ(url, response.url); + EXPECT_EQ(expected_status_code, response.status_code); + EXPECT_EQ(ErrorCode::OK, response.error.code); } TEST(InterceptorTest, DownloadChangeStatusCodeInterceptorTest) { @@ -366,4 +405,4 @@ int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); ::testing::AddGlobalTestEnvironment(server); return RUN_ALL_TESTS(); -} \ No newline at end of file +}