Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Made interceptors runnable for any number of requests on a single Session #1038

Merged
merged 4 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions cpr/multiperform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,25 @@
#include "cpr/response.h"
#include "cpr/session.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <curl/curl.h>
#include <curl/multi.h>
#include <functional>
#include <iosfwd>
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <utility>
#include <vector>

namespace cpr {

MultiPerform::MultiPerform() : multicurl_(new CurlMultiHolder()) {}
MultiPerform::MultiPerform() : multicurl_(new CurlMultiHolder()) {
current_interceptor_ = interceptors_.end();
first_interceptor_ = interceptors_.end();
}

MultiPerform::~MultiPerform() {
// Unlock all sessions
Expand Down Expand Up @@ -154,17 +159,19 @@ std::vector<Response> MultiPerform::ReadMultiInfo(const std::function<Response(S
}

std::vector<Response> MultiPerform::MakeRequest() {
if (!interceptors_.empty()) {
return intercept();
const std::optional<std::vector<Response>> r = intercept();
if (r.has_value()) {
return r.value();
}

DoMultiPerform();
return ReadMultiInfo([](Session& session, CURLcode curl_error) -> Response { return session.Complete(curl_error); });
}

std::vector<Response> MultiPerform::MakeDownloadRequest() {
if (!interceptors_.empty()) {
return intercept();
const std::optional<std::vector<Response>> r = intercept();
if (r.has_value()) {
return r.value();
}

DoMultiPerform();
Expand Down Expand Up @@ -325,15 +332,33 @@ std::vector<Response> MultiPerform::proceed() {
return MakeRequest();
}

std::vector<Response> MultiPerform::intercept() {
// At least one interceptor exists -> Execute its intercept function
const std::shared_ptr<InterceptorMulti> interceptor = interceptors_.front();
interceptors_.pop();
return interceptor->intercept(*this);
const std::optional<std::vector<Response>> MultiPerform::intercept() {
if (current_interceptor_ == interceptors_.end()) {
current_interceptor_ = first_interceptor_;
} else {
current_interceptor_++;
}

if (current_interceptor_ != interceptors_.end()) {
auto icpt = current_interceptor_;
// Nested makeRequest() start at first_interceptor_, thus excluding previous interceptors.
first_interceptor_ = current_interceptor_;
++first_interceptor_;

const std::optional<std::vector<Response>> r = (*current_interceptor_)->intercept(*this);

first_interceptor_ = icpt;

return r;
}
return std::nullopt;
}

void MultiPerform::AddInterceptor(const std::shared_ptr<InterceptorMulti>& pinterceptor) {
interceptors_.push(pinterceptor);
// Shall only add before first interceptor run
assert(current_interceptor_ == interceptors_.end());
interceptors_.push_back(pinterceptor);
first_interceptor_ = interceptors_.begin();
}

} // namespace cpr
44 changes: 34 additions & 10 deletions cpr/session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <fstream>
#include <iostream>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <utility>
Expand Down Expand Up @@ -139,11 +140,14 @@ Session::Session() : curl_(new CurlHolder()) {
#if LIBCURL_VERSION_NUM >= 0x071900 // 7.25.0
curl_easy_setopt(curl_->handle, CURLOPT_TCP_KEEPALIVE, 1L);
#endif
current_interceptor_ = interceptors_.end();
first_interceptor_ = interceptors_.end();
}

Response Session::makeDownloadRequest() {
if (!interceptors_.empty()) {
return intercept();
const std::optional<Response> r = intercept();
if (r.has_value()) {
return r.value();
}

const CURLcode curl_error = DoEasyPerform();
Expand Down Expand Up @@ -262,11 +266,13 @@ void Session::prepareCommonDownload() {
}

Response Session::makeRequest() {
if (!interceptors_.empty()) {
return intercept();
const std::optional<Response> r = intercept();
if (r.has_value()) {
return r.value();
}

const CURLcode curl_error = DoEasyPerform();

return Complete(curl_error);
}

Expand Down Expand Up @@ -880,19 +886,37 @@ Response Session::CompleteDownload(CURLcode curl_error) {
}

void Session::AddInterceptor(const std::shared_ptr<Interceptor>& pinterceptor) {
interceptors_.push(pinterceptor);
// Shall only add before first interceptor run
assert(current_interceptor_ == interceptors_.end());
interceptors_.push_back(pinterceptor);
first_interceptor_ = interceptors_.begin();
}

Response Session::proceed() {
prepareCommon();
return makeRequest();
}

Response Session::intercept() {
// At least one interceptor exists -> Execute its intercept function
const std::shared_ptr<Interceptor> interceptor = interceptors_.front();
interceptors_.pop();
return interceptor->intercept(*this);
const std::optional<Response> Session::intercept() {
if (current_interceptor_ == interceptors_.end()) {
current_interceptor_ = first_interceptor_;
} else {
current_interceptor_++;
}

if (current_interceptor_ != interceptors_.end()) {
auto icpt = current_interceptor_;
// Nested makeRequest() start at first_interceptor_, thus excluding previous interceptors.
first_interceptor_ = current_interceptor_;
++first_interceptor_;

const std::optional<Response> r = (*current_interceptor_)->intercept(*this);

first_interceptor_ = icpt;

return r;
}
return std::nullopt;
}

void Session::prepareBodyPayloadOrMultipart() const {
Expand Down
9 changes: 7 additions & 2 deletions include/cpr/multiperform.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class MultiPerform {
template <typename... DownloadArgTypes>
void PrepareDownload(DownloadArgTypes... args);

std::vector<Response> intercept();
const std::optional<std::vector<Response>> intercept();
std::vector<Response> proceed();
std::vector<Response> MakeRequest();
std::vector<Response> MakeDownloadRequest();
Expand All @@ -93,7 +93,12 @@ class MultiPerform {
std::unique_ptr<CurlMultiHolder> multicurl_;
bool is_download_multi_perform{false};

std::queue<std::shared_ptr<InterceptorMulti>> interceptors_;
using InterceptorsContainer = std::list<std::shared_ptr<InterceptorMulti>>;
InterceptorsContainer interceptors_;
// Currently running interceptor
InterceptorsContainer::iterator current_interceptor_;
// Interceptor within the chain where to start with each repeated request
InterceptorsContainer::iterator first_interceptor_;
};

template <typename CurrentDownloadArgType>
Expand Down
12 changes: 9 additions & 3 deletions include/cpr/session.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@
#include <fstream>
#include <functional>
#include <future>
#include <list>
#include <memory>
#include <optional>
#include <queue>
#include <variant>

#include "cpr/accept_encoding.h"
Expand Down Expand Up @@ -261,14 +261,20 @@ class Session : public std::enable_shared_from_this<Session> {
size_t response_string_reserve_size_{0};
std::string response_string_;
std::string header_string_;
std::queue<std::shared_ptr<Interceptor>> interceptors_;
// Container type is required to keep iterator valid on elem insertion. E.g. list but not vector.
using InterceptorsContainer = std::list<std::shared_ptr<Interceptor>>;
COM8 marked this conversation as resolved.
Show resolved Hide resolved
COM8 marked this conversation as resolved.
Show resolved Hide resolved
InterceptorsContainer interceptors_;
// Currently running interceptor
InterceptorsContainer::const_iterator current_interceptor_;
// Interceptor within the chain where to start with each repeated request
InterceptorsContainer::const_iterator first_interceptor_;
bool isUsedInMultiPerform{false};
bool isCancellable{false};

Response makeDownloadRequest();
Response makeRequest();
Response proceed();
Response intercept();
const std::optional<Response> intercept();
/**
* Prepares the curl object for a request with everything used by all requests.
**/
Expand Down
41 changes: 40 additions & 1 deletion test/interceptor_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@ class ChangeRequestMethodToPatchInterceptor : public Interceptor {
}
};

class RetryInterceptor : public Interceptor {
public:
Response intercept(Session& session) override {
// Proceed the chain
Response response = proceed(session);

// retried request
response = proceed(session);

return response;
}
};

TEST(InterceptorTest, HiddenUrlRewriteInterceptorTest) {
Url url{server->GetBaseUrl() + "/basic.json"};
Session session;
Expand All @@ -151,6 +164,32 @@ TEST(InterceptorTest, ChangeStatusCodeInterceptorTest) {
EXPECT_EQ(url, response.url);
EXPECT_EQ(expected_status_code, response.status_code);
EXPECT_EQ(ErrorCode::OK, response.error.code);

// second request
response = session.Get();
EXPECT_EQ(url, response.url);
EXPECT_EQ(expected_status_code, response.status_code);
EXPECT_EQ(ErrorCode::OK, response.error.code);
}

TEST(InterceptorTest, RetryInterceptorTest) {
Url url{server->GetBaseUrl() + "/hello.html"};
Session session;
session.SetUrl(url);
session.AddInterceptor(std::make_shared<RetryInterceptor>());
session.AddInterceptor(std::make_shared<ChangeStatusCodeInterceptor>());
Response response = session.Get();

long expected_status_code{12345};
EXPECT_EQ(url, response.url);
EXPECT_EQ(expected_status_code, response.status_code);
EXPECT_EQ(ErrorCode::OK, response.error.code);

// second request
response = session.Get();
EXPECT_EQ(url, response.url);
EXPECT_EQ(expected_status_code, response.status_code);
EXPECT_EQ(ErrorCode::OK, response.error.code);
}

TEST(InterceptorTest, DownloadChangeStatusCodeInterceptorTest) {
Expand Down Expand Up @@ -366,4 +405,4 @@ int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
::testing::AddGlobalTestEnvironment(server);
return RUN_ALL_TESTS();
}
}
Loading