Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change HostOffloader to mark every DynamicUpdateSlice which operates on host memory as host compute. This, of course, excludes DynamicUpdateSlices which are used for host offloading DMAs. #20790

Merged
merged 1 commit into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 97 additions & 17 deletions xla/hlo/transforms/host_offloader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
absl::flat_hash_set<HloInstruction*> slices_to_dynamify;
absl::flat_hash_set<HloInstruction*> custom_calls_to_insert_copies_before;
std::vector<InstructionAndShapeIndex> buffers_to_set_to_host_memory;
std::vector<HloInstruction*> dynamic_update_slices;
// std::vector<HloInstruction*> move_to_host_dynamic_update_slices;
HloInstruction* starting_instruction =
starting_instruction_and_index.instruction;
std::queue<InstructionAndShapeIndex> queue;
Expand All @@ -165,12 +165,17 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
custom_calls_to_insert_copies_before.insert(instruction);
continue;
} else if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) {
// Save every DynamicUpdateSlice we see to process after all host memory
// space propagation is done.
if (!absl::c_linear_search(dynamic_update_slices_seen_, instruction)) {
dynamic_update_slices_seen_.push_back(instruction);
}
if (instruction == starting_instruction) {
dynamic_update_slices.push_back(instruction);
} else {
// The input to this DynamicUpdateSlice is already in host memory. Save
// this so that we don't try to create an AllocateBuffer later.
dynamic_update_slices_already_allocated_.insert(instruction);
// This DynamicUpdateSlice's update operand had a MoveToHost annotation.
if (!absl::c_linear_search(dynamic_update_slices_seen_with_annotation_,
instruction)) {
dynamic_update_slices_seen_with_annotation_.push_back(instruction);
}
}
} else if (host_offload_utils::IsValidDuringPureMemoryOffload(
instruction)) {
Expand Down Expand Up @@ -240,12 +245,30 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
instruction->name());
SetHostComputeFrontendAttribute(*instruction);
}

if (!already_saved_buffer) {
// Save buffer to be set to host memory.
VLOG(5) << "Saving " << instruction_and_shape_index.ToString()
<< " to be set to host memory.";
buffers_to_set_to_host_memory.push_back(instruction_and_shape_index);
const HloInstruction* instruction =
instruction_and_shape_index.instruction;
bool set_as_host_memory = true;
if (instruction->opcode() == HloOpcode::kDynamicUpdateSlice) {
// We'll do DUSes later.
set_as_host_memory = false;

// // At this point, at least one of our operands must be in host memory
// // space. Only if the base operand is should we set the
// // DynamicUpdateSlice as in host memory.
// set_as_host_memory =
// DynamicUpdateSliceOperandIsInHostMemory(instruction,
// buffers_to_set_to_host_memory);
// LOG(INFO) << "Setting DUS " << instruction->name() << " as host
// memory? " << set_as_host_memory;
}

if (set_as_host_memory) {
// Save buffer to be set to host memory.
VLOG(5) << "Saving " << instruction_and_shape_index.ToString()
<< " to be set to host memory.";
buffers_to_set_to_host_memory.push_back(instruction_and_shape_index);
}
}

// Check if this path ends at the output of the entry computation.
Expand Down Expand Up @@ -283,12 +306,12 @@ absl::StatusOr<bool> HostOffloader::WalkDownHostMemoryOffloadPaths(
buffers_to_set_to_host_memory, Layout::kHostMemorySpace);
changed = changed || set_buffers_changed;

for (HloInstruction* dus : dynamic_update_slices) {
// Create a host AllocateBuffer instruction which this DynamicUpdateSlice
// will update-slice into.
TF_RETURN_IF_ERROR(CreateAllocateBufferForDynamicUpdateSlice(dus));
changed = true;
}
// for (HloInstruction* dus : move_to_host_dynamic_update_slices) {
// // Create a host AllocateBuffer instruction which this DynamicUpdateSlice
// // will update-slice into.
// TF_RETURN_IF_ERROR(CreateAllocateBufferForDynamicUpdateSlice(dus));
// changed = true;
// }

if (insert_copy_before) {
const auto predecessors =
Expand Down Expand Up @@ -1060,6 +1083,56 @@ absl::StatusOr<bool> HostOffloader::ProcessNextMoveToHostInstr(
return false;
}

absl::StatusOr<bool> HostOffloader::HandleDynamicUpdateSlices() {
bool changed = false;
for (HloInstruction* dus : dynamic_update_slices_seen_) {
// Look at the memory spaces of the operand and update. These should have
// already been updated by host memory space propagation. Maybe update this
// DynamicUpdateSlice depending on what memory space they are and whether or
// not the update had a MoveToHost annotation.
const int64_t operand_memory_space =
dus->operand(0)->shape().layout().memory_space();
const int64_t update_memory_space =
dus->operand(1)->shape().layout().memory_space();
const bool host_to_host = update_memory_space == Layout::kHostMemorySpace &&
operand_memory_space == Layout::kHostMemorySpace;
const bool host_to_device =
update_memory_space == Layout::kHostMemorySpace &&
operand_memory_space == Layout::kDefaultMemorySpace;
const bool device_to_host =
update_memory_space == Layout::kDefaultMemorySpace &&
operand_memory_space == Layout::kHostMemorySpace;
const bool device_to_device =
update_memory_space == Layout::kDefaultMemorySpace &&
operand_memory_space == Layout::kDefaultMemorySpace;
if (host_to_device) {
// This is only supported via host compute.
SetHostComputeFrontendAttribute(*dus);
changed = true;
} else if (host_to_host) {
// Host to host. Execute as host compute. Also set as host memory space.
SetHostComputeFrontendAttribute(*dus);
SetMemorySpace(dus->mutable_shape(), Layout::kHostMemorySpace);
changed = true;
} else if (device_to_host) {
// Device to host.
SetMemorySpace(dus->mutable_shape(), Layout::kHostMemorySpace);
changed = true;
} else if (device_to_device) {
// Device to device.
if (absl::c_linear_search(dynamic_update_slices_seen_with_annotation_,
dus)) {
// This DynamicUpdateSlice is used as a pure memory offload. Create a
// host AllocateBuffer instruction which this DynamicUpdateSlice will
// update-slice into.
TF_RETURN_IF_ERROR(CreateAllocateBufferForDynamicUpdateSlice(dus));
changed = true;
}
}
}
return changed;
}

absl::StatusOr<bool> HostOffloader::Run(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
Expand Down Expand Up @@ -1098,6 +1171,13 @@ absl::StatusOr<bool> HostOffloader::Run(
}
} while (changed_in_loop);

// For other ops, we can immediately know whether or not they need to be
// converted to host compute. DynamicUpdateSlices are different because they
// have multiple operands. Only after finishing all host memory space
// propagation can we know what to do with the DynamicUpdateSlice.
TF_ASSIGN_OR_RETURN(bool any_dus_changed, HandleDynamicUpdateSlices());
changed = changed || any_dus_changed;

// Remove all MoveToDevice custom calls.
for (HloComputation* computation :
module->MakeComputationPostOrder(execution_threads)) {
Expand Down
15 changes: 14 additions & 1 deletion xla/hlo/transforms/host_offloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,18 @@ class HostOffloader : public HloModulePass {
absl::flat_hash_set<host_offload_utils::InstructionAndShapeIndex>
already_inserted_copy_before_;

// DynamicUpdateSlices are a bit special because they are the only op which
// has multiple operands that host memory offloading supports. As a result,
// different memory propagation paths can pass through the same
// DynamicUpdateSlice. These track which paths have been seen.
std::vector<HloInstruction*> dynamic_update_slices_seen_;
std::vector<HloInstruction*> dynamic_update_slices_seen_with_annotation_;

// Maybe set DynamicUpdateSlice as host compute. Also maybe convert
// broadcast(0) to an "AllocateBuffer". Should be called only after all host
// memory propagation is done. Returns true if the module was changed.
absl::StatusOr<bool> HandleDynamicUpdateSlices();

// Sometimes previous transformations turn a DynamicSlice into a Slice. Since
// we're doing a DMA between the host and device, we need to turn the Slice
// back into a DynamicSlice.
Expand Down Expand Up @@ -122,7 +134,8 @@ class HostOffloader : public HloModulePass {

// DynamicUpdateSlices which write into host memory must have their
// destination buffer allocated on the host. This function creates the
// allocation and updates all positions to have host memory space.
// allocation and updates all positions to have host memory space. Note that
// this also sets the DynamicUpdateSlice's memory space to host memory.
absl::Status CreateAllocateBufferForDynamicUpdateSlice(
HloInstruction* dynamic_update_slice);

Expand Down
37 changes: 37 additions & 0 deletions xla/hlo/transforms/host_offloader_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4340,6 +4340,43 @@ TEST_F(HostOffloaderTest, DynamicSliceOnHostMemoryIndexCopied) {
EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(tanh));
}

TEST_F(HostOffloaderTest, DynamicUpdateSliceAllGatherDecomposer) {
const absl::string_view hlo_string = R"(
HloModule jit_f, entry_computation_layout={(s32[4]{0:T(256)S(5)})->s32[32]{0:T(256)S(5)}}, num_partitions=8

add {
x = s32[]{:T(256)} parameter(0)
y = s32[]{:T(256)} parameter(1)
ROOT add.1 = s32[]{:T(256)} add(x, y)
}

ENTRY main.5_spmd {
constant.3 = s32[]{:T(256)} constant(0)
broadcast = s32[32]{0:T(256)} broadcast(constant.3), dimensions={}
param = s32[4]{0:T(256)} parameter(0), sharding={devices=[8]<=[8]}, metadata={op_name="x"}
replica-id = u32[]{:T(256)} replica-id()
constant.1 = u32[]{:T(256)} constant(8)
multiply = u32[]{:T(256)} multiply(replica-id, constant.1)
partition-id.1 = u32[]{:T(256)} partition-id()
add = u32[]{:T(256)} add(multiply, partition-id.1)
constant.2 = u32[]{:T(256)} constant(4)
multiply.1 = u32[]{:T(256)} multiply(add, constant.2)
dynamic-update-slice = s32[32]{0:T(256)} dynamic-update-slice(broadcast, param, multiply.1), backend_config={"flag_configs":[],"scoped_memory_configs":[],"indices_config":{"index_known_bits":[{"zeroes":"3","ones":"0","bitwidth":"32"}],"is_index_aligned":[false]},"used_scoped_memory_configs":[]}
all-reduce = s32[32]{0:T(256)} all-reduce(dynamic-update-slice), channel_id=1, replica_groups=[1,8]<=[8], use_global_device_ids=true, to_apply=add
ROOT custom-call.1 = s32[32]{0:T(256)} custom-call(all-reduce), custom_call_target="MoveToHost"
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHostOffloader(module.get()));
EXPECT_TRUE(changed);
VLOG(1) << module->ToString();
HloInstruction* dynamic_update_slice =
FindInstruction(module.get(), "dynamic-update-slice");
EXPECT_TRUE(host_offload_utils::ComputeTypeIsHost(dynamic_update_slice));
}

} // namespace

} // namespace xla
Loading