From 6c0bc673625ad5e81b4590031c72bc0be1fc467c Mon Sep 17 00:00:00 2001 From: Victor Stone Date: Fri, 17 Jan 2025 16:49:31 -0800 Subject: [PATCH] Change HostOffloader to mark every DynamicUpdateSlice which operates on host memory as host compute. This, of course, excludes DynamicUpdateSlices which are used for host offloading DMAs. PiperOrigin-RevId: 716839236 --- xla/hlo/transforms/host_offloader.cc | 114 ++++++++++++++++++---- xla/hlo/transforms/host_offloader.h | 15 ++- xla/hlo/transforms/host_offloader_test.cc | 37 +++++++ 3 files changed, 148 insertions(+), 18 deletions(-) diff --git a/xla/hlo/transforms/host_offloader.cc b/xla/hlo/transforms/host_offloader.cc index 6ed6c9c2d55b3..c823618a053cf 100644 --- a/xla/hlo/transforms/host_offloader.cc +++ b/xla/hlo/transforms/host_offloader.cc @@ -138,7 +138,7 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( absl::flat_hash_set slices_to_dynamify; absl::flat_hash_set custom_calls_to_insert_copies_before; std::vector buffers_to_set_to_host_memory; - std::vector dynamic_update_slices; + // std::vector move_to_host_dynamic_update_slices; HloInstruction* starting_instruction = starting_instruction_and_index.instruction; std::queue queue; @@ -165,12 +165,17 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( custom_calls_to_insert_copies_before.insert(instruction); continue; } else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { + // Save every DynamicUpdateSlice we see to process after all host memory + // space propagation is done. + if (!absl::c_linear_search(dynamic_update_slices_seen_, instruction)) { + dynamic_update_slices_seen_.push_back(instruction); + } if (instruction == starting_instruction) { - dynamic_update_slices.push_back(instruction); - } else { - // The input to this DynamicUpdateSlice is already in host memory. Save - // this so that we don't try to create an AllocateBuffer later. - dynamic_update_slices_already_allocated_.insert(instruction); + // This DynamicUpdateSlice's update operand had a MoveToHost annotation. + if (!absl::c_linear_search(dynamic_update_slices_seen_with_annotation_, + instruction)) { + dynamic_update_slices_seen_with_annotation_.push_back(instruction); + } } } else if (host_offload_utils::IsValidDuringPureMemoryOffload( instruction)) { @@ -240,12 +245,30 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( instruction->name()); SetHostComputeFrontendAttribute(*instruction); } - if (!already_saved_buffer) { - // Save buffer to be set to host memory. - VLOG(5) << "Saving " << instruction_and_shape_index.ToString() - << " to be set to host memory."; - buffers_to_set_to_host_memory.push_back(instruction_and_shape_index); + const HloInstruction* instruction = + instruction_and_shape_index.instruction; + bool set_as_host_memory = true; + if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) { + // We'll do DUSes later. + set_as_host_memory = false; + + // // At this point, at least one of our operands must be in host memory + // // space. Only if the base operand is should we set the + // // DynamicUpdateSlice as in host memory. + // set_as_host_memory = + // DynamicUpdateSliceOperandIsInHostMemory(instruction, + // buffers_to_set_to_host_memory); + // LOG(INFO) << "Setting DUS " << instruction->name() << " as host + // memory? " << set_as_host_memory; + } + + if (set_as_host_memory) { + // Save buffer to be set to host memory. + VLOG(5) << "Saving " << instruction_and_shape_index.ToString() + << " to be set to host memory."; + buffers_to_set_to_host_memory.push_back(instruction_and_shape_index); + } } // Check if this path ends at the output of the entry computation. @@ -283,12 +306,12 @@ absl::StatusOr HostOffloader::WalkDownHostMemoryOffloadPaths( buffers_to_set_to_host_memory, Layout::kHostMemorySpace); changed = changed || set_buffers_changed; - for (HloInstruction* dus : dynamic_update_slices) { - // Create a host AllocateBuffer instruction which this DynamicUpdateSlice - // will update-slice into. - TF_RETURN_IF_ERROR(CreateAllocateBufferForDynamicUpdateSlice(dus)); - changed = true; - } + // for (HloInstruction* dus : move_to_host_dynamic_update_slices) { + // // Create a host AllocateBuffer instruction which this DynamicUpdateSlice + // // will update-slice into. + // TF_RETURN_IF_ERROR(CreateAllocateBufferForDynamicUpdateSlice(dus)); + // changed = true; + // } if (insert_copy_before) { const auto predecessors = @@ -1060,6 +1083,56 @@ absl::StatusOr HostOffloader::ProcessNextMoveToHostInstr( return false; } +absl::StatusOr HostOffloader::HandleDynamicUpdateSlices() { + bool changed = false; + for (HloInstruction* dus : dynamic_update_slices_seen_) { + // Look at the memory spaces of the operand and update. These should have + // already been updated by host memory space propagation. Maybe update this + // DynamicUpdateSlice depending on what memory space they are and whether or + // not the update had a MoveToHost annotation. + const int64_t operand_memory_space = + dus->operand(0)->shape().layout().memory_space(); + const int64_t update_memory_space = + dus->operand(1)->shape().layout().memory_space(); + const bool host_to_host = update_memory_space == Layout::kHostMemorySpace && + operand_memory_space == Layout::kHostMemorySpace; + const bool host_to_device = + update_memory_space == Layout::kHostMemorySpace && + operand_memory_space == Layout::kDefaultMemorySpace; + const bool device_to_host = + update_memory_space == Layout::kDefaultMemorySpace && + operand_memory_space == Layout::kHostMemorySpace; + const bool device_to_device = + update_memory_space == Layout::kDefaultMemorySpace && + operand_memory_space == Layout::kDefaultMemorySpace; + if (host_to_device) { + // This is only supported via host compute. + SetHostComputeFrontendAttribute(*dus); + changed = true; + } else if (host_to_host) { + // Host to host. Execute as host compute. Also set as host memory space. + SetHostComputeFrontendAttribute(*dus); + SetMemorySpace(dus->mutable_shape(), Layout::kHostMemorySpace); + changed = true; + } else if (device_to_host) { + // Device to host. + SetMemorySpace(dus->mutable_shape(), Layout::kHostMemorySpace); + changed = true; + } else if (device_to_device) { + // Device to device. + if (absl::c_linear_search(dynamic_update_slices_seen_with_annotation_, + dus)) { + // This DynamicUpdateSlice is used as a pure memory offload. Create a + // host AllocateBuffer instruction which this DynamicUpdateSlice will + // update-slice into. + TF_RETURN_IF_ERROR(CreateAllocateBufferForDynamicUpdateSlice(dus)); + changed = true; + } + } + } + return changed; +} + absl::StatusOr HostOffloader::Run( HloModule* module, const absl::flat_hash_set& execution_threads) { @@ -1098,6 +1171,13 @@ absl::StatusOr HostOffloader::Run( } } while (changed_in_loop); + // For other ops, we can immediately know whether or not they need to be + // converted to host compute. DynamicUpdateSlices are different because they + // have multiple operands. Only after finishing all host memory space + // propagation can we know what to do with the DynamicUpdateSlice. + TF_ASSIGN_OR_RETURN(bool any_dus_changed, HandleDynamicUpdateSlices()); + changed = changed || any_dus_changed; + // Remove all MoveToDevice custom calls. for (HloComputation* computation : module->MakeComputationPostOrder(execution_threads)) { diff --git a/xla/hlo/transforms/host_offloader.h b/xla/hlo/transforms/host_offloader.h index 5055aa15f10a8..94ebf4a0a4fb5 100644 --- a/xla/hlo/transforms/host_offloader.h +++ b/xla/hlo/transforms/host_offloader.h @@ -84,6 +84,18 @@ class HostOffloader : public HloModulePass { absl::flat_hash_set already_inserted_copy_before_; + // DynamicUpdateSlices are a bit special because they are the only op which + // has multiple operands that host memory offloading supports. As a result, + // different memory propagation paths can pass through the same + // DynamicUpdateSlice. These track which paths have been seen. + std::vector dynamic_update_slices_seen_; + std::vector dynamic_update_slices_seen_with_annotation_; + + // Maybe set DynamicUpdateSlice as host compute. Also maybe convert + // broadcast(0) to an "AllocateBuffer". Should be called only after all host + // memory propagation is done. Returns true if the module was changed. + absl::StatusOr HandleDynamicUpdateSlices(); + // Sometimes previous transformations turn a DynamicSlice into a Slice. Since // we're doing a DMA between the host and device, we need to turn the Slice // back into a DynamicSlice. @@ -122,7 +134,8 @@ class HostOffloader : public HloModulePass { // DynamicUpdateSlices which write into host memory must have their // destination buffer allocated on the host. This function creates the - // allocation and updates all positions to have host memory space. + // allocation and updates all positions to have host memory space. Note that + // this also sets the DynamicUpdateSlice's memory space to host memory. absl::Status CreateAllocateBufferForDynamicUpdateSlice( HloInstruction* dynamic_update_slice); diff --git a/xla/hlo/transforms/host_offloader_test.cc b/xla/hlo/transforms/host_offloader_test.cc index 0ca2d5d2dad5e..7670c1910f2ba 100644 --- a/xla/hlo/transforms/host_offloader_test.cc +++ b/xla/hlo/transforms/host_offloader_test.cc @@ -4340,6 +4340,43 @@ TEST_F(HostOffloaderTest, DynamicSliceOnHostMemoryIndexCopied) { EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh)); } +TEST_F(HostOffloaderTest, DynamicUpdateSliceAllGatherDecomposer) { + const absl::string_view hlo_string = R"( +HloModule jit_f, entry_computation_layout={(s32[4]{0:T(256)S(5)})->s32[32]{0:T(256)S(5)}}, num_partitions=8 + +add { + x = s32[]{:T(256)} parameter(0) + y = s32[]{:T(256)} parameter(1) + ROOT add.1 = s32[]{:T(256)} add(x, y) +} + +ENTRY main.5_spmd { + constant.3 = s32[]{:T(256)} constant(0) + broadcast = s32[32]{0:T(256)} broadcast(constant.3), dimensions={} + param = s32[4]{0:T(256)} parameter(0), sharding={devices=[8]<=[8]}, metadata={op_name="x"} + replica-id = u32[]{:T(256)} replica-id() + constant.1 = u32[]{:T(256)} constant(8) + multiply = u32[]{:T(256)} multiply(replica-id, constant.1) + partition-id.1 = u32[]{:T(256)} partition-id() + add = u32[]{:T(256)} add(multiply, partition-id.1) + constant.2 = u32[]{:T(256)} constant(4) + multiply.1 = u32[]{:T(256)} multiply(add, constant.2) + dynamic-update-slice = s32[32]{0:T(256)} dynamic-update-slice(broadcast, param, multiply.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"3","ones":"0","bitwidth":"32"}],"is_index_aligned":[false]},"used_scoped_memory_configs":[]} + all-reduce = s32[32]{0:T(256)} all-reduce(dynamic-update-slice), channel_id=1, replica_groups=[1,8]<=[8], use_global_device_ids=true, to_apply=add + ROOT custom-call.1 = s32[32]{0:T(256)} custom-call(all-reduce), custom_call_target="MoveToHost" +} +)"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + ParseAndReturnVerifiedModule(hlo_string)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get())); + EXPECT_TRUE(changed); + VLOG(1) << module->ToString(); + HloInstruction* dynamic_update_slice = + FindInstruction(module.get(), "dynamic-update-slice"); + EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(dynamic_update_slice)); +} + } // namespace } // namespace xla