diff --git a/src/distributed_graph_generator.cc b/src/distributed_graph_generator.cc index f0b885b9a..683099b11 100644 --- a/src/distributed_graph_generator.cc +++ b/src/distributed_graph_generator.cc @@ -184,8 +184,8 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk) // TODO: Revisit this at some point. const node_id reduction_initializer_nid = 0; - const box<3> empty_box({0, 0, 0}, {0, 0, 0}); - const box<3> scalar_box({0, 0, 0}, {1, 1, 1}); + const box<3> empty_reduction_box({0, 0, 0}, {0, 0, 0}); + const box<3> scalar_reduction_box({0, 0, 0}, {1, 1, 1}); // Iterate over all chunks, distinguish between local / remote chunks and normal / reduction access. // @@ -216,7 +216,7 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk) assert(requirements[reduction.bid].count(pmode) == 0); // task_manager verifies that there are no reduction <-> write-access conflicts } #endif - requirements[reduction.bid][rmode] = scalar_box; + requirements[reduction.bid][rmode] = scalar_reduction_box; } abstract_command* cmd = nullptr; @@ -356,7 +356,7 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk) if(generate_reduction) { const auto& reduction = *buffer_state.pending_reduction; - const auto local_last_writer = buffer_state.local_last_writer.get_region_values(scalar_box); + const auto local_last_writer = buffer_state.local_last_writer.get_region_values(scalar_reduction_box); assert(local_last_writer.size() == 1); if(is_local_chunk) { @@ -367,18 +367,18 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk) m_cdag.add_dependency(reduce_cmd, m_cdag.get(local_last_writer[0].second), dependency_kind::true_dep, dependency_origin::dataflow); } - auto* const ap_cmd = create_command(bid, reduction.rid, trid, scalar_box.get_subrange()); + auto* const ap_cmd = create_command(bid, reduction.rid, trid, scalar_reduction_box.get_subrange()); m_cdag.add_dependency(reduce_cmd, ap_cmd, dependency_kind::true_dep, dependency_origin::dataflow); generate_epoch_dependencies(ap_cmd); m_cdag.add_dependency(cmd, reduce_cmd, dependency_kind::true_dep, dependency_origin::dataflow); // Reduction command becomes the last writer (this may be overriden if this task also writes to the reduction buffer) - post_reduction_buffer_states.at(bid).local_last_writer.update_box(scalar_box, reduce_cmd->get_cid()); + post_reduction_buffer_states.at(bid).local_last_writer.update_box(scalar_reduction_box, reduce_cmd->get_cid()); } else { // Push an empty range if we don't have any fresh data on this node const bool notification_only = !local_last_writer[0].second.is_fresh(); - const auto push_box = notification_only ? empty_box : scalar_box; + const auto push_box = notification_only ? empty_reduction_box : scalar_reduction_box; auto* const push_cmd = create_command(bid, reduction.rid, nid, trid, push_box.get_subrange()); generated_pushes.push_back(push_cmd); @@ -386,16 +386,16 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk) if(notification_only) { generate_epoch_dependencies(push_cmd); } else { - m_command_buffer_reads[push_cmd->get_cid()][bid] = region_union(m_command_buffer_reads[push_cmd->get_cid()][bid], scalar_box); + m_command_buffer_reads[push_cmd->get_cid()][bid] = region_union(m_command_buffer_reads[push_cmd->get_cid()][bid], scalar_reduction_box); m_cdag.add_dependency(push_cmd, m_cdag.get(local_last_writer[0].second), dependency_kind::true_dep, dependency_origin::dataflow); } // Mark the reduction result as replicated so we don't generate data transfers to this node // TODO: We need a way of updating regions in place! E.g. apply_to_values(box, callback) - const auto replicated_box = post_reduction_buffer_states.at(bid).replicated_regions.get_region_values(scalar_box); + const auto replicated_box = post_reduction_buffer_states.at(bid).replicated_regions.get_region_values(scalar_reduction_box); assert(replicated_box.size() == 1); for(const auto& [_, nodes] : replicated_box) { - post_reduction_buffer_states.at(bid).replicated_regions.update_box(scalar_box, node_bitset{nodes}.set(nid)); + post_reduction_buffer_states.at(bid).replicated_regions.update_box(scalar_reduction_box, node_bitset{nodes}.set(nid)); } } } @@ -484,6 +484,11 @@ void distributed_graph_generator::generate_distributed_commands(const task& tsk) // Determine which local data is fresh/stale based on task-level writes. auto requirements = get_buffer_requirements_for_mapped_access(tsk, subrange<3>(tsk.get_global_offset(), tsk.get_global_size()), tsk.get_global_size()); + // Add requirements for reductions + for(const auto& reduction : tsk.get_reductions()) { + // the actual mode is irrelevant as long as it's a producer - TODO have a better query API for task buffer requirements + requirements[reduction.bid][access_mode::write] = scalar_reduction_box; + } for(auto& [bid, reqs_by_mode] : requirements) { box_vector<3> global_write_boxes; for(const auto mode : access::producer_modes) { diff --git a/test/graph_gen_reduction_tests.cc b/test/graph_gen_reduction_tests.cc index c1acad24a..91fc6fded 100644 --- a/test/graph_gen_reduction_tests.cc +++ b/test/graph_gen_reduction_tests.cc @@ -219,7 +219,7 @@ TEST_CASE("reduction commands anti-depend on their partial-result push commands" auto buf = dctx.create_buffer(range<1>(1)); const auto tid_producer = dctx.device_compute(range<1>(num_nodes)).reduce(buf, false /* include_current_buffer_value */).submit(); - const auto tid_consumer = dctx.device_compute(range<1>(num_nodes)).read(buf, acc::all{}).submit(); + /* const auto tid_consumer = */ dctx.device_compute(range<1>(num_nodes)).read(buf, acc::all{}).submit(); CHECK(dctx.query(tid_producer) .assert_count_per_node(1) @@ -227,3 +227,19 @@ TEST_CASE("reduction commands anti-depend on their partial-result push commands" .assert_count_per_node(1) .have_successors(dctx.query(command_type::reduction).assert_count_per_node(1), dependency_kind::anti_dep)); } + +TEST_CASE("reduction in a single-node task does not generate a reduction command, but the result is await-pushed on other nodes", + "[distributed_graph_generator][command-graph][reductions]") { + const size_t num_nodes = 3; + dist_cdag_test_context dctx(num_nodes); + auto buf = dctx.create_buffer(range<1>(1)); + + const auto tid_producer = dctx.device_compute(range<1>(1)).reduce(buf, false /* include_current_buffer_value */).submit(); + const auto tid_consumer = dctx.device_compute(range<1>(num_nodes)).read(buf, acc::all()).submit(); + + CHECK(dctx.query(command_type::reduction).count() == 0); + CHECK(dctx.query(tid_producer).assert_count(1).have_successors(dctx.query(node_id(0), command_type::push).assert_count(2))); + for(node_id nid_await : {node_id(1), node_id(2)}) { + CHECK(dctx.query(nid_await, command_type::await_push).assert_count(1).have_successors(dctx.query(nid_await, tid_consumer))); + } +}