Skip to content

Commit

Permalink
fix rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Jan 12, 2025
1 parent 90ab596 commit 5921570
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 81 deletions.
8 changes: 8 additions & 0 deletions mlx/distributed/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ class EmptyGroup : public GroupImpl {
throw std::runtime_error(
"Communication not implemented in an empty distributed group.");
}
void barrier() override {
throw std::runtime_error(
"Barrier not implemented in an empty distributed group.");
}
};

} // namespace detail
Expand All @@ -80,6 +84,10 @@ Group Group::split(int color, int key /* = -1 */) const {
return Group(group_->split(color, key));
}

void Group::barrier() {
return group_->barrier();
}

Group init(bool strict /* = false */) {
auto init_group = [strict]() {
auto default_group = mpi::init(strict);
Expand Down
1 change: 0 additions & 1 deletion mlx/distributed/distributed.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ struct Group {

void barrier();


private:
std::shared_ptr<detail::GroupImpl> group_{nullptr};
};
Expand Down
1 change: 1 addition & 0 deletions mlx/distributed/distributed_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class GroupImpl {
virtual void all_gather(const array& input, array& output) = 0;
virtual void send(const array& input, int dst) = 0;
virtual void recv(array& out, int src) = 0;
virtual void barrier() = 0;
};

/* Return the communication stream. */
Expand Down
36 changes: 1 addition & 35 deletions mlx/distributed/mpi/mpi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ class MPIGroup : public GroupImpl {
&status);
}

void barrier() {
void barrier() override {
mpi().barrier(comm_);
}

Expand All @@ -332,40 +332,6 @@ class MPIGroup : public GroupImpl {
int size_;
};

MPI_Comm to_comm(Group& group) {
return std::static_pointer_cast<MPIGroupImpl>(group.raw_group())->comm();
}

} // namespace

int Group::rank() {
return std::static_pointer_cast<MPIGroupImpl>(group_)->rank();
}

int Group::size() {
return std::static_pointer_cast<MPIGroupImpl>(group_)->size();
}

Group Group::split(int color, int key) {
auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);

key = (key < 0) ? rank() : key;

MPI_Comm new_comm;
int result = mpi().comm_split(mpi_group->comm(), color, key, &new_comm);
if (result != MPI_SUCCESS) {
throw std::runtime_error("MPI could not split this group");
}

return Group(std::make_shared<MPIGroupImpl>(new_comm, false));
}

void Group::barrier() {
auto mpi_group = std::static_pointer_cast<MPIGroupImpl>(group_);
mpi_group->barrier();
}

>>>>>>> c3ccd4919 (Add MPI barrier)
bool is_available() {
return mpi().is_available();
}
Expand Down
44 changes: 0 additions & 44 deletions mlx/distributed/no_distributed.cpp

This file was deleted.

5 changes: 4 additions & 1 deletion python/src/distributed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ void init_distributed(nb::module_& parent_module) {
key (int, optional): A key to optionally change the rank ordering
of the processes.
)pbdoc")
.def("barrier", &distributed::Group::barrier, "Make a synhronization point for all nodes in the group");
.def(
"barrier",
&mx::distributed::Group::barrier,
"Make a synhronization point for all nodes in the group");

m.def(
"is_available",
Expand Down

0 comments on commit 5921570

Please sign in to comment.