diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 95d7fd67..cb82fa75 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -16,13 +16,6 @@ find_package(Boost REQUIRED) add_subdirectory(third_party) -add_library(brad_server_lib OBJECT - server/brad_server_simple.cc - server/brad_sql_info.cc - server/brad_statement_batch_reader.cc - server/brad_statement.cc - server/brad_tables_schema_batch_reader.cc) - add_library(sqlite_server_lib OBJECT sqlite_server/sqlite_server.cc sqlite_server/sqlite_sql_info.cc @@ -31,12 +24,18 @@ add_library(sqlite_server_lib OBJECT sqlite_server/sqlite_tables_schema_batch_reader.cc sqlite_server/sqlite_type_info.cc) -pybind11_add_module(pybind_brad_server pybind/brad_server.cc) +pybind11_add_module(pybind_brad_server pybind/brad_server.cc + server/brad_server_simple.cc + server/brad_sql_info.cc + server/brad_statement_batch_reader.cc + server/brad_statement.cc + server/brad_tables_schema_batch_reader.cc) + target_link_libraries(pybind_brad_server PRIVATE Arrow::arrow_shared PRIVATE ArrowFlight::arrow_flight_shared PRIVATE ArrowFlightSql::arrow_flight_sql_shared - PRIVATE brad_server_lib) + PUBLIC libcuckoo) add_executable(flight_sql_example_client flight_sql_example_client.cc) target_link_libraries(flight_sql_example_client @@ -55,14 +54,6 @@ target_link_libraries(flight_sql_example_server ${SQLite3_LIBRARIES} ${Boost_LIBRARIES}) -add_executable(flight_sql_brad_server flight_sql_brad_server.cc) -target_link_libraries(flight_sql_brad_server - PRIVATE Arrow::arrow_shared - PRIVATE ArrowFlight::arrow_flight_shared - PRIVATE ArrowFlightSql::arrow_flight_sql_shared - PRIVATE brad_server_lib - gflags) - add_executable(brad_front_end brad_front_end.cc) target_link_libraries(brad_front_end PRIVATE Arrow::arrow_shared diff --git a/cpp/pybind/brad_server.cc b/cpp/pybind/brad_server.cc index be006996..3c8c2ac5 100644 --- a/cpp/pybind/brad_server.cc +++ b/cpp/pybind/brad_server.cc @@ -1,4 +1,6 @@ #include +#include +#include #include diff --git a/cpp/server/brad_server_simple.cc b/cpp/server/brad_server_simple.cc index 19321e76..cb090030 100644 --- a/cpp/server/brad_server_simple.cc +++ b/cpp/server/brad_server_simple.cc @@ -22,14 +22,16 @@ using arrow::internal::checked_cast; using namespace arrow::flight; using namespace arrow::flight::sql; -arrow::Result EncodeTransactionQuery( - const std::string &query, +std::string GetQueryTicket( + const std::string &autoincrement_id, const std::string &transaction_id) { - std::string transaction_query = transaction_id; - transaction_query += ':'; - transaction_query += query; + return transaction_id + ':' + autoincrement_id; +} + +arrow::Result EncodeTransactionQuery( + const std::string &query_ticket) { ARROW_ASSIGN_OR_RAISE(auto ticket_string, - CreateStatementQueryTicket(transaction_query)); + CreateStatementQueryTicket(query_ticket)); return Ticket{std::move(ticket_string)}; } @@ -40,17 +42,35 @@ arrow::Result> DecodeTransactionQuery( return arrow::Status::Invalid("Malformed ticket"); } std::string transaction_id = ticket.substr(0, divider); - std::string query = ticket.substr(divider + 1); - return std::make_pair(std::move(query), std::move(transaction_id)); + std::string autoincrement_id = ticket.substr(divider + 1); + return std::make_pair(std::move(autoincrement_id), std::move(transaction_id)); } -BradFlightSqlServer::BradFlightSqlServer() = default; +std::vector> TransformQueryResult( + std::vector query_result) { + std::vector> transformed_query_result; + for (const auto &row : query_result) { + std::vector transformed_row{}; + for (const auto &field : row) { + if (py::isinstance(field)) { + transformed_row.push_back(std::make_any(py::cast(field))); + } else if (py::isinstance(field)) { + transformed_row.push_back(std::make_any(py::cast(field))); + } else { + transformed_row.push_back(std::make_any(py::cast(field))); + } + } + transformed_query_result.push_back(transformed_row); + } + return transformed_query_result; +} + +BradFlightSqlServer::BradFlightSqlServer() : autoincrement_id_(0ULL) {} BradFlightSqlServer::~BradFlightSqlServer() = default; std::shared_ptr BradFlightSqlServer::Create() { - // std::shared_ptr result(new BradFlightSqlServer()); std::shared_ptr result = std::make_shared(); for (const auto &id_to_result : GetSqlInfoResultMap()) { @@ -59,9 +79,15 @@ std::shared_ptr return result; } -void BradFlightSqlServer::InitWrapper(const std::string &host, int port) { +void BradFlightSqlServer::InitWrapper( + const std::string &host, + int port, + std::function(std::string)> handle_query) { auto location = arrow::flight::Location::ForGrpcTcp(host, port).ValueOrDie(); arrow::flight::FlightServerOptions options(location); + + handle_query_ = handle_query; + this->Init(options); } @@ -79,10 +105,25 @@ arrow::Result> const StatementQuery &command, const FlightDescriptor &descriptor) { const std::string &query = command.query; - ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(query)); - ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); + + const std::string &autoincrement_id = std::to_string(++autoincrement_id_); + const std::string &query_ticket = GetQueryTicket(autoincrement_id, command.transaction_id); ARROW_ASSIGN_OR_RAISE(auto ticket, - EncodeTransactionQuery(query, command.transaction_id)); + EncodeTransactionQuery(query_ticket)); + + std::vector> transformed_query_result; + + { + py::gil_scoped_acquire guard; + std::vector query_result = handle_query_(query); + transformed_query_result = TransformQueryResult(query_result); + } + + ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(transformed_query_result)); + query_data_.insert(query_ticket, statement); + + ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); + std::vector endpoints{ FlightEndpoint{std::move(ticket), {}, std::nullopt, ""}}; @@ -103,14 +144,23 @@ arrow::Result> const StatementQueryTicket &command) { ARROW_ASSIGN_OR_RAISE(auto pair, DecodeTransactionQuery(command.statement_handle)); - const std::string &sql = pair.first; + const std::string &autoincrement_id = pair.first; const std::string transaction_id = pair.second; - std::shared_ptr statement; - ARROW_ASSIGN_OR_RAISE(statement, BradStatement::Create(sql)); + const std::string &query_ticket = transaction_id + ':' + autoincrement_id; + + std::shared_ptr result; + const bool found = query_data_.erase_fn(query_ticket, [&result](auto& qr) { + result = qr; + return true; + }); + + if (!found) { + return arrow::Status::Invalid("Invalid ticket."); + } std::shared_ptr reader; - ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(statement)); + ARROW_ASSIGN_OR_RAISE(reader, BradStatementBatchReader::Create(result)); return std::make_unique(reader); } diff --git a/cpp/server/brad_server_simple.h b/cpp/server/brad_server_simple.h index f6db4cbf..d2e0c186 100644 --- a/cpp/server/brad_server_simple.h +++ b/cpp/server/brad_server_simple.h @@ -1,12 +1,23 @@ #pragma once +#include #include +#include #include #include +#include #include +#include "brad_statement.h" #include +#include "libcuckoo/cuckoohash_map.hh" + +#include + +namespace py = pybind11; +using namespace pybind11::literals; + namespace brad { class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { @@ -17,7 +28,9 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { static std::shared_ptr Create(); - void InitWrapper(const std::string &host, int port); + void InitWrapper(const std::string &host, + int port, + std::function(std::string)>); void ServeWrapper(); @@ -33,6 +46,13 @@ class BradFlightSqlServer : public arrow::flight::sql::FlightSqlServerBase { DoGetStatement( const arrow::flight::ServerCallContext &context, const arrow::flight::sql::StatementQueryTicket &command) override; + + private: + std::function(std::string)> handle_query_; + + libcuckoo::cuckoohash_map> query_data_; + + std::atomic autoincrement_id_; }; } // namespace brad diff --git a/cpp/server/brad_statement.cc b/cpp/server/brad_statement.cc index 7f791c5a..e9ce1588 100644 --- a/cpp/server/brad_statement.cc +++ b/cpp/server/brad_statement.cc @@ -24,47 +24,97 @@ arrow::Result> BradStatement::Create( return result; } +arrow::Result> BradStatement::Create( + std::vector> query_result) { + std::shared_ptr result( + std::make_shared(query_result)); + return result; +} + +BradStatement::BradStatement(std::vector> query_result) : + query_result_(std::move(query_result)) {} + BradStatement::~BradStatement() { } arrow::Result> BradStatement::GetSchema() const { + if (schema_) { + return schema_; + } + std::vector> fields; - fields.push_back(arrow::field("Day", arrow::int8())); - fields.push_back(arrow::field("Month", arrow::int8())); - fields.push_back(arrow::field("Year", arrow::int16())); - return arrow::schema(fields); + + if (query_result_.size() > 0) { + const std::vector &row = query_result_[0]; + + int counter = 0; + for (const auto &field : row) { + std::string field_type = field.type().name(); + if (field_type == "i") { + fields.push_back(arrow::field("INT FIELD " + std::to_string(++counter), arrow::int8())); + } else if (field_type == "f") { + fields.push_back(arrow::field("FLOAT FIELD " + std::to_string(++counter), arrow::float32())); + } else { + fields.push_back(arrow::field("STRING FIELD " + std::to_string(++counter), arrow::utf8())); + } + } + } + + schema_ = arrow::schema(fields); + return schema_; } arrow::Result> BradStatement::FetchResult() { - arrow::Int8Builder int8builder; - int8_t days_raw[5] = {1, 12, 17, 23, 28}; - ARROW_RETURN_NOT_OK(int8builder.AppendValues(days_raw, 5)); - std::shared_ptr days; - ARROW_ASSIGN_OR_RAISE(days, int8builder.Finish()); - - int8_t months_raw[5] = {1, 3, 5, 7, 1}; - ARROW_RETURN_NOT_OK(int8builder.AppendValues(months_raw, 5)); - std::shared_ptr months; - ARROW_ASSIGN_OR_RAISE(months, int8builder.Finish()); - - arrow::Int16Builder int16builder; - int16_t years_raw[5] = {1990, 2000, 1995, 2000, 1995}; - ARROW_RETURN_NOT_OK(int16builder.AppendValues(years_raw, 5)); - std::shared_ptr years; - ARROW_ASSIGN_OR_RAISE(years, int16builder.Finish()); - - std::shared_ptr record_batch; - - arrow::Result> result = GetSchema(); - if (result.ok()) { - std::shared_ptr schema = result.ValueOrDie(); - record_batch = arrow::RecordBatch::Make(schema, - days->length(), - {days, months, years}); - return record_batch; + std::shared_ptr schema = GetSchema().ValueOrDie(); + + const int num_rows = query_result_.size(); + + std::vector> columns; + columns.reserve(schema->num_fields()); + + for (int field_ix = 0; field_ix < schema->num_fields(); ++field_ix) { + const auto &field = schema->fields()[field_ix]; + if (field->type() == arrow::int8()) { + arrow::Int8Builder int8builder; + int8_t values_raw[num_rows]; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); + } + ARROW_RETURN_NOT_OK(int8builder.AppendValues(values_raw, num_rows)); + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, int8builder.Finish()); + + columns.push_back(values); + } else if (field->type() == arrow::float32()) { + arrow::FloatBuilder floatbuilder; + float values_raw[num_rows]; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + values_raw[row_ix] = std::any_cast(query_result_[row_ix][field_ix]); + } + ARROW_RETURN_NOT_OK(floatbuilder.AppendValues(values_raw, num_rows)); + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish()); + + columns.push_back(values); + } else if (field->type() == arrow::utf8()) { + arrow::StringBuilder stringbuilder; + for (int row_ix = 0; row_ix < num_rows; ++row_ix) { + const std::string* str = std::any_cast(&(query_result_[row_ix][field_ix])); + ARROW_RETURN_NOT_OK(stringbuilder.Append(str->data(), str->size())); + } + + std::shared_ptr values; + ARROW_ASSIGN_OR_RAISE(values, stringbuilder.Finish()); + } } - return arrow::Status::OK(); + std::shared_ptr record_batch = + arrow::RecordBatch::Make(schema, + num_rows, + columns); + return record_batch; } std::string* BradStatement::GetBradStmt() const { return stmt_; } diff --git a/cpp/server/brad_statement.h b/cpp/server/brad_statement.h index 482829c9..b3dba2cc 100644 --- a/cpp/server/brad_statement.h +++ b/cpp/server/brad_statement.h @@ -1,7 +1,9 @@ #pragma once #include +#include #include +#include #include #include @@ -23,6 +25,11 @@ class BradStatement { static arrow::Result> Create( const std::string& sql); + static arrow::Result> Create( + const std::vector>); + + BradStatement(std::vector>); + ~BradStatement(); /// \brief Creates an Arrow Schema based on the results of this statement. @@ -34,6 +41,10 @@ class BradStatement { std::string* GetBradStmt() const; private: + std::vector> query_result_; + + mutable std::shared_ptr schema_; + std::string* stmt_; BradStatement(std::string* stmt) : stmt_(stmt) {} diff --git a/cpp/server/brad_statement_batch_reader.cc b/cpp/server/brad_statement_batch_reader.cc index 16ef38cd..48c9d5f2 100644 --- a/cpp/server/brad_statement_batch_reader.cc +++ b/cpp/server/brad_statement_batch_reader.cc @@ -13,7 +13,8 @@ BradStatementBatchReader::BradStatementBatchReader( std::shared_ptr statement, std::shared_ptr schema) : statement_(std::move(statement)), - schema_(std::move(schema)) {} + schema_(std::move(schema)), + already_executed_(false) {} arrow::Result> BradStatementBatchReader::Create( diff --git a/cpp/third_party/CMakeLists.txt b/cpp/third_party/CMakeLists.txt index 002e0b50..8ba1d026 100644 --- a/cpp/third_party/CMakeLists.txt +++ b/cpp/third_party/CMakeLists.txt @@ -12,4 +12,10 @@ FetchContent_Declare( GIT_TAG v2.2.2 ) -FetchContent_MakeAvailable(pybind11 gflags) +FetchContent_Declare( + libcuckoo + GIT_REPOSITORY https://github.com/efficient/libcuckoo.git + GIT_TAG 784d0f5d147b9a73f897ae55f6c3712d9a91b058 +) + +FetchContent_MakeAvailable(pybind11 gflags libcuckoo) diff --git a/src/brad/front_end/flight_sql_server.py b/src/brad/front_end/flight_sql_server.py index eb736e20..22152e8e 100644 --- a/src/brad/front_end/flight_sql_server.py +++ b/src/brad/front_end/flight_sql_server.py @@ -1,5 +1,6 @@ import logging import threading +from typing import Callable # pylint: disable-next=import-error,no-name-in-module,unused-import import brad.native.pybind_brad_server as brad_server @@ -8,9 +9,9 @@ class BradFlightSqlServer: - def __init__(self, host: str, port: int) -> None: + def __init__(self, host: str, port: int, callback: Callable) -> None: self._flight_sql_server = brad_server.BradFlightSqlServer() - self._flight_sql_server.init(host, port) + self._flight_sql_server.init(host, port, callback) self._thread = threading.Thread(name="BradFlightSqlServer", target=self._serve) def start(self) -> None: diff --git a/src/brad/front_end/front_end.py b/src/brad/front_end/front_end.py index d00b47e0..f460e2bf 100644 --- a/src/brad/front_end/front_end.py +++ b/src/brad/front_end/front_end.py @@ -90,11 +90,18 @@ def __init__( from brad.front_end.flight_sql_server import BradFlightSqlServer self._flight_sql_server: Optional[BradFlightSqlServer] = ( - BradFlightSqlServer(host="0.0.0.0", port=31337) + BradFlightSqlServer( + host="0.0.0.0", + port=31337, + callback=self._handle_query_from_flight_sql, + ) ) + self._flight_sql_server_session_id: Optional[SessionId] = None else: self._flight_sql_server = None + self._main_thread_loop: Optional[asyncio.AbstractEventLoop] = None + self._fe_index = fe_index self._config = config self._schema_name = schema_name @@ -190,11 +197,24 @@ def __init__( self._is_stub_mode = self._config.stub_mode_path is not None + def _handle_query_from_flight_sql(self, query: str) -> RowList: + assert self._flight_sql_server_session_id is not None + assert self._main_thread_loop is not None + + future = asyncio.run_coroutine_threadsafe( + self._run_query_impl(self._flight_sql_server_session_id, query, {}), + self._main_thread_loop, + ) + row_result = future.result() + + return row_result + async def serve_forever(self): await self._run_setup() # Start FlightSQL server if self._flight_sql_server is not None: + self._flight_sql_server_session_id = await self.start_session() self._flight_sql_server.start() try: @@ -219,6 +239,8 @@ async def serve_forever(self): logger.debug("BRAD front end _run_teardown() complete.") async def _run_setup(self) -> None: + self._main_thread_loop = asyncio.get_running_loop() + # The directory will have been populated by the daemon. await self._blueprint_mgr.load(skip_directory_refresh=True) logger.info("Using blueprint: %s", self._blueprint_mgr.get_blueprint()) @@ -239,7 +261,7 @@ async def _run_setup(self) -> None: if not self._is_stub_mode: self._qlogger_refresh_task = asyncio.create_task(self._refresh_qlogger()) - self._watchdog.start(asyncio.get_running_loop()) + self._watchdog.start(self._main_thread_loop) self._ping_watchdog_task = asyncio.create_task(self._ping_watchdog()) async def _set_up_router(self) -> None: