Skip to content

Commit

Permalink
Change HostOffloader to mark every DynamicUpdateSlice which operates …
Browse files Browse the repository at this point in the history
…on host memory as host compute. This, of course, excludes DynamicUpdateSlices which are used for host offloading DMAs.

PiperOrigin-RevId: 708403064
  • Loading branch information
SandSnip3r authored and Google-ML-Automation committed Jan 18, 2025
1 parent cd2aac8 commit 8d70d23
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 18 deletions.
114 changes: 97 additions & 17 deletions xla/hlo/transforms/host_offloader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
absl::flat_hash_set<HloInstruction*> slices_to_dynamify;
absl::flat_hash_set<HloInstruction*> custom_calls_to_insert_copies_before;
std::vector<InstructionAndShapeIndex> buffers_to_set_to_host_memory;
std::vector<HloInstruction*> dynamic_update_slices;
// std::vector<HloInstruction*> move_to_host_dynamic_update_slices;
HloInstruction* starting_instruction =
starting_instruction_and_index.instruction;
std::queue<InstructionAndShapeIndex> queue;
Expand All @@ -165,12 +165,17 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
custom_calls_to_insert_copies_before.insert(instruction);
continue;
} else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) {
// Save every DynamicUpdateSlice we see to process after all host memory
// space propagation is done.
if (!absl::c_linear_search(dynamic_update_slices_seen_, instruction)) {
dynamic_update_slices_seen_.push_back(instruction);
}
if (instruction == starting_instruction) {
dynamic_update_slices.push_back(instruction);
} else {
// The input to this DynamicUpdateSlice is already in host memory. Save
// this so that we don't try to create an AllocateBuffer later.
dynamic_update_slices_already_allocated_.insert(instruction);
// This DynamicUpdateSlice's update operand had a MoveToHost annotation.
if (!absl::c_linear_search(dynamic_update_slices_seen_with_annotation_,
instruction)) {
dynamic_update_slices_seen_with_annotation_.push_back(instruction);
}
}
} else if (host_offload_utils::IsValidDuringPureMemoryOffload(
instruction)) {
Expand Down Expand Up @@ -240,12 +245,30 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
instruction->name());
SetHostComputeFrontendAttribute(*instruction);
}

if (!already_saved_buffer) {
// Save buffer to be set to host memory.
VLOG(5) << "Saving " << instruction_and_shape_index.ToString()
<< " to be set to host memory.";
buffers_to_set_to_host_memory.push_back(instruction_and_shape_index);
const HloInstruction* instruction =
instruction_and_shape_index.instruction;
bool set_as_host_memory = true;
if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) {
// We'll do DUSes later.
set_as_host_memory = false;

// // At this point, at least one of our operands must be in host memory
// // space. Only if the base operand is should we set the
// // DynamicUpdateSlice as in host memory.
// set_as_host_memory =
// DynamicUpdateSliceOperandIsInHostMemory(instruction,
// buffers_to_set_to_host_memory);
// LOG(INFO) << "Setting DUS " << instruction->name() << " as host
// memory? " << set_as_host_memory;
}

if (set_as_host_memory) {
// Save buffer to be set to host memory.
VLOG(5) << "Saving " << instruction_and_shape_index.ToString()
<< " to be set to host memory.";
buffers_to_set_to_host_memory.push_back(instruction_and_shape_index);
}
}

// Check if this path ends at the output of the entry computation.
Expand Down Expand Up @@ -283,12 +306,12 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
buffers_to_set_to_host_memory, Layout::kHostMemorySpace);
changed = changed || set_buffers_changed;

for (HloInstruction* dus : dynamic_update_slices) {
// Create a host AllocateBuffer instruction which this DynamicUpdateSlice
// will update-slice into.
TF_RETURN_IF_ERROR(CreateAllocateBufferForDynamicUpdateSlice(dus));
changed = true;
}
// for (HloInstruction* dus : move_to_host_dynamic_update_slices) {
// // Create a host AllocateBuffer instruction which this DynamicUpdateSlice
// // will update-slice into.
// TF_RETURN_IF_ERROR(CreateAllocateBufferForDynamicUpdateSlice(dus));
// changed = true;
// }

if (insert_copy_before) {
const auto predecessors =
Expand Down Expand Up @@ -1060,6 +1083,56 @@ absl::StatusOr<bool> HostOffloader::ProcessNextMoveToHostInstr(
return false;
}

absl::StatusOr<bool> HostOffloader::HandleDynamicUpdateSlices() {
bool changed = false;
for (HloInstruction* dus : dynamic_update_slices_seen_) {
// Look at the memory spaces of the operand and update. These should have
// already been updated by host memory space propagation. Maybe update this
// DynamicUpdateSlice depending on what memory space they are and whether or
// not the update had a MoveToHost annotation.
const int64_t operand_memory_space =
dus->operand(0)->shape().layout().memory_space();
const int64_t update_memory_space =
dus->operand(1)->shape().layout().memory_space();
const bool host_to_host = update_memory_space == Layout::kHostMemorySpace &&
operand_memory_space == Layout::kHostMemorySpace;
const bool host_to_device =
update_memory_space == Layout::kHostMemorySpace &&
operand_memory_space == Layout::kDefaultMemorySpace;
const bool device_to_host =
update_memory_space == Layout::kDefaultMemorySpace &&
operand_memory_space == Layout::kHostMemorySpace;
const bool device_to_device =
update_memory_space == Layout::kDefaultMemorySpace &&
operand_memory_space == Layout::kDefaultMemorySpace;
if (host_to_device) {
// This is only supported via host compute.
SetHostComputeFrontendAttribute(*dus);
changed = true;
} else if (host_to_host) {
// Host to host. Execute as host compute. Also set as host memory space.
SetHostComputeFrontendAttribute(*dus);
SetMemorySpace(dus->mutable_shape(), Layout::kHostMemorySpace);
changed = true;
} else if (device_to_host) {
// Device to host.
SetMemorySpace(dus->mutable_shape(), Layout::kHostMemorySpace);
changed = true;
} else if (device_to_device) {
// Device to device.
if (absl::c_linear_search(dynamic_update_slices_seen_with_annotation_,
dus)) {
// This DynamicUpdateSlice is used as a pure memory offload. Create a
// host AllocateBuffer instruction which this DynamicUpdateSlice will
// update-slice into.
TF_RETURN_IF_ERROR(CreateAllocateBufferForDynamicUpdateSlice(dus));
changed = true;
}
}
}
return changed;
}

absl::StatusOr<bool> HostOffloader::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand Down Expand Up @@ -1098,6 +1171,13 @@ absl::StatusOr<bool> HostOffloader::Run(
}
} while (changed_in_loop);

// For other ops, we can immediately know whether or not they need to be
// converted to host compute. DynamicUpdateSlices are different because they
// have multiple operands. Only after finishing all host memory space
// propagation can we know what to do with the DynamicUpdateSlice.
TF_ASSIGN_OR_RETURN(bool any_dus_changed, HandleDynamicUpdateSlices());
changed = changed || any_dus_changed;

// Remove all MoveToDevice custom calls.
for (HloComputation* computation :
module->MakeComputationPostOrder(execution_threads)) {
Expand Down
15 changes: 14 additions & 1 deletion xla/hlo/transforms/host_offloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ class HostOffloader : public HloModulePass {
absl::flat_hash_set<host_offload_utils::InstructionAndShapeIndex>
already_inserted_copy_before_;

// DynamicUpdateSlices are a bit special because they are the only op which
// has multiple operands that host memory offloading supports. As a result,
// different memory propagation paths can pass through the same
// DynamicUpdateSlice. These track which paths have been seen.
std::vector<HloInstruction*> dynamic_update_slices_seen_;
std::vector<HloInstruction*> dynamic_update_slices_seen_with_annotation_;

// Maybe set DynamicUpdateSlice as host compute. Also maybe convert
// broadcast(0) to an "AllocateBuffer". Should be called only after all host
// memory propagation is done. Returns true if the module was changed.
absl::StatusOr<bool> HandleDynamicUpdateSlices();

// Sometimes previous transformations turn a DynamicSlice into a Slice. Since
// we're doing a DMA between the host and device, we need to turn the Slice
// back into a DynamicSlice.
Expand Down Expand Up @@ -122,7 +134,8 @@ class HostOffloader : public HloModulePass {

// DynamicUpdateSlices which write into host memory must have their
// destination buffer allocated on the host. This function creates the
// allocation and updates all positions to have host memory space.
// allocation and updates all positions to have host memory space. Note that
// this also sets the DynamicUpdateSlice's memory space to host memory.
absl::Status CreateAllocateBufferForDynamicUpdateSlice(
HloInstruction* dynamic_update_slice);

Expand Down
37 changes: 37 additions & 0 deletions xla/hlo/transforms/host_offloader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4340,6 +4340,43 @@ TEST_F(HostOffloaderTest, DynamicSliceOnHostMemoryIndexCopied) {
EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh));
}

TEST_F(HostOffloaderTest, DynamicUpdateSliceAllGatherDecomposer) {
const absl::string_view hlo_string = R"(
HloModule jit_f, entry_computation_layout={(s32[4]{0:T(256)S(5)})->s32[32]{0:T(256)S(5)}}, num_partitions=8
add {
x = s32[]{:T(256)} parameter(0)
y = s32[]{:T(256)} parameter(1)
ROOT add.1 = s32[]{:T(256)} add(x, y)
}
ENTRY main.5_spmd {
constant.3 = s32[]{:T(256)} constant(0)
broadcast = s32[32]{0:T(256)} broadcast(constant.3), dimensions={}
param = s32[4]{0:T(256)} parameter(0), sharding={devices=[8]<=[8]}, metadata={op_name="x"}
replica-id = u32[]{:T(256)} replica-id()
constant.1 = u32[]{:T(256)} constant(8)
multiply = u32[]{:T(256)} multiply(replica-id, constant.1)
partition-id.1 = u32[]{:T(256)} partition-id()
add = u32[]{:T(256)} add(multiply, partition-id.1)
constant.2 = u32[]{:T(256)} constant(4)
multiply.1 = u32[]{:T(256)} multiply(add, constant.2)
dynamic-update-slice = s32[32]{0:T(256)} dynamic-update-slice(broadcast, param, multiply.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"3","ones":"0","bitwidth":"32"}],"is_index_aligned":[false]},"used_scoped_memory_configs":[]}
all-reduce = s32[32]{0:T(256)} all-reduce(dynamic-update-slice), channel_id=1, replica_groups=[1,8]<=[8], use_global_device_ids=true, to_apply=add
ROOT custom-call.1 = s32[32]{0:T(256)} custom-call(all-reduce), custom_call_target="MoveToHost"
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));
EXPECT_TRUE(changed);
VLOG(1) << module->ToString();
HloInstruction* dynamic_update_slice =
FindInstruction(module.get(), "dynamic-update-slice");
EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(dynamic_update_slice));
}

} // namespace

} // namespace xla

0 comments on commit 8d70d23

Please sign in to comment.