Skip to content

Commit

Permalink
Replace direct cudaMemcpyAsync calls with utility functions (limite…
Browse files Browse the repository at this point in the history
…d to `cudf::io`) (rapidsai#17132)

Issue rapidsai#15620

Replaced the calls to `cudaMemcpyAsync` with the new `cuda_memcpy`/`cuda_memcpy_async` utility, which optionally avoids using the copy engine. Changes are limited to cuIO to make the PR easier to review (repetitive enough as-is!).

Also took the opportunity to use `cudf::detail::host_vector` and its factories to enable wider pinned memory use.

Skipped a few instances of `cudaMemcpyAsync`; few are under `io::comp`, which we don't want to invest in further (if possible). The other `cudaMemcpyAsync` instances are D2D copies, which `cuda_memcpy`/`cuda_memcpy_async` don't support. Perhaps they should, just to make the use ubiquitous.

Authors:
  - Vukasin Milovanovic (https://github.com/vuule)

Approvers:
  - Paul Mattione (https://github.com/pmattione-nvidia)
  - Nghia Truong (https://github.com/ttnghia)

URL: rapidsai#17132
  • Loading branch information
vuule authored Oct 23, 2024
1 parent 02ee819 commit deb9af4
Show file tree
Hide file tree
Showing 18 changed files with 134 additions and 172 deletions.
6 changes: 4 additions & 2 deletions cpp/src/io/comp/uncomp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,8 +538,10 @@ size_t decompress_zstd(host_span<uint8_t const> src,
CUDF_EXPECTS(hd_stats[0].status == compression_status::SUCCESS, "ZSTD decompression failed");

// Copy temporary output to `dst`
CUDF_CUDA_TRY(cudaMemcpyAsync(
dst.data(), d_dst.data(), hd_stats[0].bytes_written, cudaMemcpyDefault, stream.value()));
cudf::detail::cuda_memcpy_async(
dst.subspan(0, hd_stats[0].bytes_written),
device_span<uint8_t const>{d_dst.data(), hd_stats[0].bytes_written},
stream);

return hd_stats[0].bytes_written;
}
Expand Down
57 changes: 23 additions & 34 deletions cpp/src/io/csv/reader_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include "csv_common.hpp"
#include "csv_gpu.hpp"
#include "cudf/detail/utilities/cuda_memcpy.hpp"
#include "io/comp/io_uncomp.hpp"
#include "io/utilities/column_buffer.hpp"
#include "io/utilities/hostdevice_vector.hpp"
Expand Down Expand Up @@ -275,11 +276,10 @@ std::pair<rmm::device_uvector<char>, selected_rows_offsets> load_data_and_gather
auto const read_offset = byte_range_offset + input_pos + previous_data_size;
auto const read_size = target_pos - input_pos - previous_data_size;
if (data.has_value()) {
CUDF_CUDA_TRY(cudaMemcpyAsync(d_data.data() + previous_data_size,
data->data() + read_offset,
target_pos - input_pos - previous_data_size,
cudaMemcpyDefault,
stream.value()));
cudf::detail::cuda_memcpy_async(
device_span<char>{d_data.data() + previous_data_size, read_size},
data->subspan(read_offset, read_size),
stream);
} else {
if (source->is_device_read_preferred(read_size)) {
source->device_read(read_offset,
Expand All @@ -288,12 +288,11 @@ std::pair<rmm::device_uvector<char>, selected_rows_offsets> load_data_and_gather
stream);
} else {
auto const buffer = source->host_read(read_offset, read_size);
CUDF_CUDA_TRY(cudaMemcpyAsync(d_data.data() + previous_data_size,
buffer->data(),
buffer->size(),
cudaMemcpyDefault,
stream.value()));
stream.synchronize(); // To prevent buffer going out of scope before we copy the data.
// Use sync version to prevent buffer going out of scope before we copy the data.
cudf::detail::cuda_memcpy(
device_span<char>{d_data.data() + previous_data_size, read_size},
host_span<char const>{reinterpret_cast<char const*>(buffer->data()), buffer->size()},
stream);
}
}

Expand All @@ -311,12 +310,10 @@ std::pair<rmm::device_uvector<char>, selected_rows_offsets> load_data_and_gather
range_end,
skip_rows,
stream);
CUDF_CUDA_TRY(cudaMemcpyAsync(row_ctx.host_ptr(),
row_ctx.device_ptr(),
num_blocks * sizeof(uint64_t),
cudaMemcpyDefault,
stream.value()));
stream.synchronize();

cudf::detail::cuda_memcpy(host_span<uint64_t>{row_ctx}.subspan(0, num_blocks),
device_span<uint64_t const>{row_ctx}.subspan(0, num_blocks),
stream);

// Sum up the rows in each character block, selecting the row count that
// corresponds to the current input context. Also stores the now known input
Expand All @@ -331,11 +328,9 @@ std::pair<rmm::device_uvector<char>, selected_rows_offsets> load_data_and_gather
// At least one row in range in this batch
all_row_offsets.resize(total_rows - skip_rows, stream);

CUDF_CUDA_TRY(cudaMemcpyAsync(row_ctx.device_ptr(),
row_ctx.host_ptr(),
num_blocks * sizeof(uint64_t),
cudaMemcpyDefault,
stream.value()));
cudf::detail::cuda_memcpy_async(device_span<uint64_t>{row_ctx}.subspan(0, num_blocks),
host_span<uint64_t const>{row_ctx}.subspan(0, num_blocks),
stream);

// Pass 2: Output row offsets
cudf::io::csv::gpu::gather_row_offsets(parse_opts.view(),
Expand All @@ -352,12 +347,9 @@ std::pair<rmm::device_uvector<char>, selected_rows_offsets> load_data_and_gather
stream);
// With byte range, we want to keep only one row out of the specified range
if (range_end < data_size) {
CUDF_CUDA_TRY(cudaMemcpyAsync(row_ctx.host_ptr(),
row_ctx.device_ptr(),
num_blocks * sizeof(uint64_t),
cudaMemcpyDefault,
stream.value()));
stream.synchronize();
cudf::detail::cuda_memcpy(host_span<uint64_t>{row_ctx}.subspan(0, num_blocks),
device_span<uint64_t const>{row_ctx}.subspan(0, num_blocks),
stream);

size_t rows_out_of_range = 0;
for (uint32_t i = 0; i < num_blocks; i++) {
Expand Down Expand Up @@ -401,12 +393,9 @@ std::pair<rmm::device_uvector<char>, selected_rows_offsets> load_data_and_gather
// Remove header rows and extract header
auto const header_row_index = std::max<size_t>(header_rows, 1) - 1;
if (header_row_index + 1 < row_offsets.size()) {
CUDF_CUDA_TRY(cudaMemcpyAsync(row_ctx.host_ptr(),
row_offsets.data() + header_row_index,
2 * sizeof(uint64_t),
cudaMemcpyDefault,
stream.value()));
stream.synchronize();
cudf::detail::cuda_memcpy(host_span<uint64_t>{row_ctx}.subspan(0, 2),
device_span<uint64_t const>{row_offsets.data() + header_row_index, 2},
stream);

auto const header_start = input_pos + row_ctx[0];
auto const header_end = input_pos + row_ctx[1];
Expand Down
10 changes: 3 additions & 7 deletions cpp/src/io/csv/writer_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <cudf/detail/copy.hpp>
#include <cudf/detail/fill.hpp>
#include <cudf/detail/null_mask.hpp>
#include <cudf/detail/utilities/vector_factories.hpp>
#include <cudf/io/data_sink.hpp>
#include <cudf/io/detail/csv.hpp>
#include <cudf/null_mask.hpp>
Expand Down Expand Up @@ -405,13 +406,8 @@ void write_chunked(data_sink* out_sink,
out_sink->device_write(ptr_all_bytes, total_num_bytes, stream);
} else {
// copy the bytes to host to write them out
thrust::host_vector<char> h_bytes(total_num_bytes);
CUDF_CUDA_TRY(cudaMemcpyAsync(h_bytes.data(),
ptr_all_bytes,
total_num_bytes * sizeof(char),
cudaMemcpyDefault,
stream.value()));
stream.synchronize();
auto const h_bytes = cudf::detail::make_host_vector_sync(
device_span<char const>{ptr_all_bytes, total_num_bytes}, stream);

out_sink->host_write(h_bytes.data(), total_num_bytes);
}
Expand Down
24 changes: 8 additions & 16 deletions cpp/src/io/json/host_tree_algorithms.cu
Original file line number Diff line number Diff line change
Expand Up @@ -198,17 +198,13 @@ NodeIndexT get_row_array_parent_col_id(device_span<NodeIndexT const> col_ids,
bool is_enabled_lines,
rmm::cuda_stream_view stream)
{
NodeIndexT value = parent_node_sentinel;
if (!col_ids.empty()) {
auto const list_node_index = is_enabled_lines ? 0 : 1;
CUDF_CUDA_TRY(cudaMemcpyAsync(&value,
col_ids.data() + list_node_index,
sizeof(NodeIndexT),
cudaMemcpyDefault,
stream.value()));
stream.synchronize();
}
return value;
if (col_ids.empty()) { return parent_node_sentinel; }

auto const list_node_index = is_enabled_lines ? 0 : 1;
auto const value = cudf::detail::make_host_vector_sync(
device_span<NodeIndexT const>{col_ids.data() + list_node_index, 1}, stream);

return value[0];
}
/**
* @brief Holds member data pointers of `d_json_column`
Expand Down Expand Up @@ -818,11 +814,7 @@ std::pair<cudf::detail::host_vector<bool>, hashmap_of_device_columns> build_tree
column_categories.cbegin(),
expected_types.begin(),
[](auto exp, auto cat) { return exp == NUM_NODE_CLASSES ? cat : exp; });
cudaMemcpyAsync(d_column_tree.node_categories.begin(),
expected_types.data(),
expected_types.size() * sizeof(column_categories[0]),
cudaMemcpyDefault,
stream.value());
cudf::detail::cuda_memcpy_async<NodeT>(d_column_tree.node_categories, expected_types, stream);

return {is_pruned, columns};
}
Expand Down
16 changes: 7 additions & 9 deletions cpp/src/io/json/json_column.cu
Original file line number Diff line number Diff line change
Expand Up @@ -513,16 +513,14 @@ table_with_metadata device_parse_nested_json(device_span<SymbolT const> d_input,
#endif

bool const is_array_of_arrays = [&]() {
std::array<node_t, 2> h_node_categories = {NC_ERR, NC_ERR};
auto const size_to_copy = std::min(size_t{2}, gpu_tree.node_categories.size());
CUDF_CUDA_TRY(cudaMemcpyAsync(h_node_categories.data(),
gpu_tree.node_categories.data(),
sizeof(node_t) * size_to_copy,
cudaMemcpyDefault,
stream.value()));
stream.synchronize();
auto const size_to_copy = std::min(size_t{2}, gpu_tree.node_categories.size());
if (size_to_copy == 0) return false;
auto const h_node_categories = cudf::detail::make_host_vector_sync(
device_span<NodeT const>{gpu_tree.node_categories.data(), size_to_copy}, stream);

if (options.is_enabled_lines()) return h_node_categories[0] == NC_LIST;
return h_node_categories[0] == NC_LIST and h_node_categories[1] == NC_LIST;
return h_node_categories.size() >= 2 and h_node_categories[0] == NC_LIST and
h_node_categories[1] == NC_LIST;
}();

auto [gpu_col_id, gpu_row_offsets] =
Expand Down
15 changes: 6 additions & 9 deletions cpp/src/io/json/json_tree.cu
Original file line number Diff line number Diff line change
Expand Up @@ -264,16 +264,13 @@ tree_meta_t get_tree_representation(device_span<PdaTokenT const> tokens,
error_count > 0) {
auto const error_location =
thrust::find(rmm::exec_policy(stream), tokens.begin(), tokens.end(), token_t::ErrorBegin);
SymbolOffsetT error_index;
CUDF_CUDA_TRY(
cudaMemcpyAsync(&error_index,
token_indices.data() + thrust::distance(tokens.begin(), error_location),
sizeof(SymbolOffsetT),
cudaMemcpyDefault,
stream.value()));
stream.synchronize();
auto error_index = cudf::detail::make_host_vector_sync<SymbolOffsetT>(
device_span<SymbolOffsetT const>{
token_indices.data() + thrust::distance(tokens.begin(), error_location), 1},
stream);

CUDF_FAIL("JSON Parser encountered an invalid format at location " +
std::to_string(error_index));
std::to_string(error_index[0]));
}

auto const num_tokens = tokens.size();
Expand Down
13 changes: 6 additions & 7 deletions cpp/src/io/json/read_json.cu
Original file line number Diff line number Diff line change
Expand Up @@ -315,13 +315,12 @@ device_span<char> ingest_raw_input(device_span<char> buffer,
// Reading to host because decompression of a single block is much faster on the CPU
sources[0]->host_read(range_offset, remaining_bytes_to_read, hbuffer.data());
auto uncomp_data = decompress(compression, hbuffer);
CUDF_CUDA_TRY(cudaMemcpyAsync(buffer.data(),
reinterpret_cast<char*>(uncomp_data.data()),
uncomp_data.size() * sizeof(char),
cudaMemcpyHostToDevice,
stream.value()));
stream.synchronize();
return buffer.first(uncomp_data.size());
auto ret_buffer = buffer.first(uncomp_data.size());
cudf::detail::cuda_memcpy<char>(
ret_buffer,
host_span<char const>{reinterpret_cast<char const*>(uncomp_data.data()), uncomp_data.size()},
stream);
return ret_buffer;
}

table_with_metadata read_json(host_span<std::unique_ptr<datasource>> sources,
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/io/orc/reader_impl_chunking.cu
Original file line number Diff line number Diff line change
Expand Up @@ -516,10 +516,10 @@ void reader_impl::load_next_stripe_data(read_mode mode)
_stream.synchronize();
stream_synchronized = true;
}
device_read_tasks.push_back(
std::pair(source_ptr->device_read_async(
read_info.offset, read_info.length, dst_base + read_info.dst_pos, _stream),
read_info.length));
device_read_tasks.emplace_back(
source_ptr->device_read_async(
read_info.offset, read_info.length, dst_base + read_info.dst_pos, _stream),
read_info.length);

} else {
auto buffer = source_ptr->host_read(read_info.offset, read_info.length);
Expand Down
26 changes: 13 additions & 13 deletions cpp/src/io/orc/writer_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
* @brief cuDF-IO ORC writer class implementation
*/

#include "cudf/detail/utilities/cuda_memcpy.hpp"
#include "io/comp/nvcomp_adapter.hpp"
#include "io/orc/orc_gpu.hpp"
#include "io/statistics/column_statistics.cuh"
Expand Down Expand Up @@ -1408,7 +1409,8 @@ encoded_footer_statistics finish_statistic_blobs(Footer const& footer,
num_entries_seen += stripes_per_col;
}

std::vector<statistics_merge_group> file_stats_merge(num_file_blobs);
auto file_stats_merge =
cudf::detail::make_host_vector<statistics_merge_group>(num_file_blobs, stream);
for (auto i = 0u; i < num_file_blobs; ++i) {
auto col_stats = &file_stats_merge[i];
col_stats->col_dtype = per_chunk_stats.col_types[i];
Expand All @@ -1418,11 +1420,10 @@ encoded_footer_statistics finish_statistic_blobs(Footer const& footer,
}

auto d_file_stats_merge = stats_merge.device_ptr(num_stripe_blobs);
CUDF_CUDA_TRY(cudaMemcpyAsync(d_file_stats_merge,
file_stats_merge.data(),
num_file_blobs * sizeof(statistics_merge_group),
cudaMemcpyDefault,
stream.value()));
cudf::detail::cuda_memcpy_async<statistics_merge_group>(
device_span<statistics_merge_group>{stats_merge.device_ptr(num_stripe_blobs), num_file_blobs},
file_stats_merge,
stream);

auto file_stat_chunks = stat_chunks.data() + num_stripe_blobs;
detail::merge_group_statistics<detail::io_file_format::ORC>(
Expand Down Expand Up @@ -1573,7 +1574,7 @@ void write_index_stream(int32_t stripe_id,
* @param[in] strm_desc Stream's descriptor
* @param[in] enc_stream Chunk's streams
* @param[in] compressed_data Compressed stream data
* @param[in,out] stream_out Temporary host output buffer
* @param[in,out] bounce_buffer Pinned memory bounce buffer for D2H data transfer
* @param[in,out] stripe Stream's parent stripe
* @param[in,out] streams List of all streams
* @param[in] compression_kind The compression kind
Expand All @@ -1584,7 +1585,7 @@ void write_index_stream(int32_t stripe_id,
std::future<void> write_data_stream(gpu::StripeStream const& strm_desc,
gpu::encoder_chunk_streams const& enc_stream,
uint8_t const* compressed_data,
uint8_t* stream_out,
host_span<uint8_t> bounce_buffer,
StripeInformation* stripe,
orc_streams* streams,
CompressionKind compression_kind,
Expand All @@ -1604,11 +1605,10 @@ std::future<void> write_data_stream(gpu::StripeStream const& strm_desc,
if (out_sink->is_device_write_preferred(length)) {
return out_sink->device_write_async(stream_in, length, stream);
} else {
CUDF_CUDA_TRY(
cudaMemcpyAsync(stream_out, stream_in, length, cudaMemcpyDefault, stream.value()));
stream.synchronize();
cudf::detail::cuda_memcpy(
bounce_buffer.subspan(0, length), device_span<uint8_t const>{stream_in, length}, stream);

out_sink->host_write(stream_out, length);
out_sink->host_write(bounce_buffer.data(), length);
return std::async(std::launch::deferred, [] {});
}
}();
Expand Down Expand Up @@ -2616,7 +2616,7 @@ void writer::impl::write_orc_data_to_sink(encoded_data const& enc_data,
strm_desc,
enc_data.streams[strm_desc.column_id][segmentation.stripes[stripe_id].first],
compressed_data.data(),
bounce_buffer.data(),
bounce_buffer,
&stripe,
&streams,
_compression_kind,
Expand Down
21 changes: 12 additions & 9 deletions cpp/src/io/parquet/predicate_pushdown.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -454,15 +454,18 @@ std::optional<std::vector<std::vector<size_type>>> aggregate_reader_metadata::fi
CUDF_EXPECTS(predicate.type().id() == cudf::type_id::BOOL8,
"Filter expression must return a boolean column");

auto num_bitmasks = num_bitmask_words(predicate.size());
std::vector<bitmask_type> host_bitmask(num_bitmasks, ~bitmask_type{0});
if (predicate.nullable()) {
CUDF_CUDA_TRY(cudaMemcpyAsync(host_bitmask.data(),
predicate.null_mask(),
num_bitmasks * sizeof(bitmask_type),
cudaMemcpyDefault,
stream.value()));
}
auto const host_bitmask = [&] {
auto const num_bitmasks = num_bitmask_words(predicate.size());
if (predicate.nullable()) {
return cudf::detail::make_host_vector_sync(
device_span<bitmask_type const>(predicate.null_mask(), num_bitmasks), stream);
} else {
auto bitmask = cudf::detail::make_host_vector<bitmask_type>(num_bitmasks, stream);
std::fill(bitmask.begin(), bitmask.end(), ~bitmask_type{0});
return bitmask;
}
}();

auto validity_it = cudf::detail::make_counting_transform_iterator(
0, [bitmask = host_bitmask.data()](auto bit_index) { return bit_is_set(bitmask, bit_index); });

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/io/parquet/reader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ void reader::impl::decode_page_data(read_mode mode, size_t skip_rows, size_t num
// TODO: This step is somewhat redundant if size info has already been calculated (nested schema,
// chunked reader).
auto const has_strings = (kernel_mask & STRINGS_MASK) != 0;
std::vector<size_t> col_string_sizes(_input_columns.size(), 0L);
auto col_string_sizes = cudf::detail::make_host_vector<size_t>(_input_columns.size(), _stream);
if (has_strings) {
// need to compute pages bounds/sizes if we lack page indexes or are using custom bounds
// TODO: we could probably dummy up size stats for FLBA data since we know the width
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/io/parquet/reader_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ class reader::impl {
*
* @return Vector of total string data sizes for each column
*/
std::vector<size_t> calculate_page_string_offsets();
cudf::detail::host_vector<size_t> calculate_page_string_offsets();

/**
* @brief Converts the page data and outputs to columns.
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/io/parquet/reader_impl_chunking.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ struct subpass_intermediate_data {
* rowgroups may represent less than all of the rowgroups to be read for the file.
*/
struct pass_intermediate_data {
std::vector<std::unique_ptr<datasource::buffer>> raw_page_data;
std::vector<rmm::device_buffer> raw_page_data;

// rowgroup, chunk and page information for the current pass.
bool has_compressed_data{false};
Expand Down
Loading

0 comments on commit deb9af4

Please sign in to comment.