Skip to content

Commit

Permalink
Separate Decimal field type handler, check for Null fields, handle nulltype data
Browse files Browse the repository at this point in the history
  • Loading branch information
Sophie Zhang committed Apr 29, 2024
1 parent a7ac465 commit ea7c72d
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 18 deletions.
74 changes: 56 additions & 18 deletions cpp/server/brad_server_simple.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ arrow::Result<std::pair<std::string, std::string>> DecodeTransactionQuery(
return std::make_pair(std::move(autoincrement_id), std::move(transaction_id));
}

arrow::Result<std::shared_ptr<arrow::RecordBatch>>
ResultToRecordBatch(std::vector<py::tuple> query_result, std::shared_ptr<arrow::Schema> schema) {
const int num_rows = query_result.size();
arrow::Result<std::shared_ptr<arrow::RecordBatch>> ResultToRecordBatch(
const std::vector<py::tuple> &query_result,
const std::shared_ptr<arrow::Schema> &schema) {
const size_t num_rows = query_result.size();

const int num_columns = schema->num_fields();
const size_t num_columns = schema->num_fields();
std::vector<std::shared_ptr<arrow::Array>> columns;
columns.reserve(num_columns);

Expand All @@ -64,32 +65,56 @@ ResultToRecordBatch(std::vector<py::tuple> query_result, std::shared_ptr<arrow::
if (field_type->Equals(arrow::int64())) {
arrow::Int64Builder int64builder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const int64_t val = py::cast<int64_t>(query_result[row_ix][field_ix]);
// TODO: How do we check for null values in ints or floats?
ARROW_RETURN_NOT_OK(int64builder.Append(val));
const std::optional<int64_t> val =
py::cast<std::optional<int64_t>>(query_result[row_ix][field_ix]);
if (val) {
ARROW_RETURN_NOT_OK(int64builder.Append(*val));
} else {
ARROW_RETURN_NOT_OK(int64builder.AppendNull());
}
}
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, int64builder.Finish());
columns.push_back(values);

} else if (field_type->Equals(arrow::float32()) ||
// TODO: Should not hardcode precision and scale values
field_type->Equals(arrow::decimal(/*precision=*/10, /*scale=*/2))) {
} else if (field_type->Equals(arrow::float32())) {
arrow::FloatBuilder floatbuilder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const float val = py::cast<float>(query_result[row_ix][field_ix]);
ARROW_RETURN_NOT_OK(floatbuilder.Append(val));
const std::optional<float> val =
py::cast<std::optional<float>>(query_result[row_ix][field_ix]);
if (val) {
ARROW_RETURN_NOT_OK(floatbuilder.Append(*val));
} else {
ARROW_RETURN_NOT_OK(floatbuilder.AppendNull());
}
}
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, floatbuilder.Finish());
columns.push_back(values);

} else if (field_type->Equals(arrow::decimal(/*precision=*/10, /*scale=*/2))) {
arrow::Decimal128Builder decimalbuilder(arrow::decimal(/*precision=*/10, /*scale=*/2));
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const std::optional<std::string> val =
py::cast<std::optional<std::string>>(query_result[row_ix][field_ix]);
if (val) {
ARROW_RETURN_NOT_OK(
decimalbuilder.Append(arrow::Decimal128::FromString(*val).ValueOrDie()));
} else {
ARROW_RETURN_NOT_OK(decimalbuilder.AppendNull());
}
}
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, decimalbuilder.Finish());
columns.push_back(values);

} else if (field_type->Equals(arrow::utf8())) {
arrow::StringBuilder stringbuilder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const std::string str = py::cast<std::string>(query_result[row_ix][field_ix]);
if (str.empty()) {
ARROW_RETURN_NOT_OK(stringbuilder.Append(str.data(), str.size()));
const std::optional<std::string> str =
py::cast<std::optional<std::string>>(query_result[row_ix][field_ix]);
if (str) {
ARROW_RETURN_NOT_OK(stringbuilder.Append(str->data(), str->size()));
} else {
ARROW_RETURN_NOT_OK(stringbuilder.AppendNull());
}
Expand All @@ -101,13 +126,26 @@ ResultToRecordBatch(std::vector<py::tuple> query_result, std::shared_ptr<arrow::
} else if (field_type->Equals(arrow::date64())) {
arrow::Date64Builder datebuilder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
const int64_t val = py::cast<int64_t>(query_result[row_ix][field_ix]);
ARROW_RETURN_NOT_OK(datebuilder.Append(val));
const std::optional<int64_t> val =
py::cast<std::optional<int64_t>>(query_result[row_ix][field_ix]);
if (val) {
ARROW_RETURN_NOT_OK(datebuilder.Append(*val));
} else {
ARROW_RETURN_NOT_OK(datebuilder.AppendNull());
}
}
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, datebuilder.Finish());
columns.push_back(values);

} else if (field_type->Equals(arrow::null())) {
arrow::NullBuilder nullbuilder;
for (int row_ix = 0; row_ix < num_rows; ++row_ix) {
ARROW_RETURN_NOT_OK(nullbuilder.AppendNull());
}
std::shared_ptr<arrow::Array> values;
ARROW_ASSIGN_OR_RAISE(values, nullbuilder.Finish());
columns.push_back(values);
}
}

Expand Down Expand Up @@ -179,7 +217,7 @@ arrow::Result<std::unique_ptr<FlightInfo>>
py::gil_scoped_acquire guard;
auto result = handle_query_(query);
result_schema = ArrowSchemaFromBradSchema(result.second);
result_record_batch = ResultToRecordBatch(result.first, result_schema).ValueOrDie();
result_record_batch = ResultToRecordBatch(std::move(result.first), result_schema).ValueOrDie();
}

ARROW_ASSIGN_OR_RAISE(auto statement, BradStatement::Create(result_record_batch, result_schema));
Expand Down
1 change: 1 addition & 0 deletions cpp/server/brad_server_simple.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "libcuckoo/cuckoohash_map.hh"

#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

namespace brad {

Expand Down

0 comments on commit ea7c72d

Please sign in to comment.