Support Consecutive Tasks in Open MoE LLM Leaderboard (#17)
* fix:avoid redundant peer access enable

* fix:resource cleanup in every initialization

* fix:distinguish between initialization hooks and runtime hooks

* fix:reset pretrained method in `__exit__`
lausannel authored Apr 29, 2024
1 parent 8ef1206 commit 912abad
Showing 4 changed files with 83 additions and 26 deletions.
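The changes below let the leaderboard harness run several evaluation tasks consecutively in one process: each task re-creates the offloaded model, so the C++ memory pools, CUDA peer-access setup, and Python hooks must be re-initialized and torn down cleanly every time. A minimal sketch of that usage pattern follows; the MoE wrapper, the config keys, and the checkpoint name are assumptions taken from the project README, not part of this diff.

from transformers import AutoTokenizer
from moe_infinity import MoE

checkpoint = "google/switch-base-32"
config = {"offload_path": "/tmp/moe-infinity-offload", "device_memory_ratio": 0.75}
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Two consecutive "tasks": each builds a fresh offloaded model in the same process,
# which is exactly the lifecycle this commit makes safe to repeat.
for task in ["translate English to German: Hello.", "summarize: mixture-of-experts models route tokens to experts."]:
    model = MoE(checkpoint, config)  # re-runs engine initialization (pools, peer access, hooks)
    inputs = tokenizer(task, return_tensors="pt")
    outputs = model.generate(inputs.input_ids, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
    del model  # destructors must release host/device pools before the next task starts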
4 changes: 2 additions & 2 deletions core/memory/memory_pool.cpp
@@ -17,8 +17,8 @@
#include <sys/sysinfo.h>
#include <unistd.h>

std::unique_ptr<HostMemoryPool> kHostMemoryPool = std::make_unique<HostMemoryPool>();
std::unique_ptr<DeviceMemoryPool> kDeviceMemoryPool = std::make_unique<DeviceMemoryPool>();
std::unique_ptr<HostMemoryPool> kHostMemoryPool(nullptr);
std::unique_ptr<DeviceMemoryPool> kDeviceMemoryPool(nullptr);

std::size_t GetTotalSystemMemory()
{
30 changes: 27 additions & 3 deletions core/memory/memory_pool.h
@@ -7,7 +7,9 @@
#include "common/pytorch.h"
#include "utils/noncopyable.h"

#include <c10/core/CPUAllocator.h>
#include "utils/archer_logger.h"
#include "host_caching_allocator.h"
#include <c10/cuda/CUDACachingAllocator.h>
#include <mutex>
#include <unordered_map>
#include <unordered_set>
@@ -35,7 +37,16 @@ class HostMemoryPool : public noncopyable {
    std::int64_t GetMemoryCapacity();

    HostMemoryPool();
    virtual ~HostMemoryPool() = default;
    virtual ~HostMemoryPool()
    {
        auto allocator = c10::HostCachingAllocator::get();
        for (auto& [key, data_ptr] : allocated_id_) {
            if (data_ptr != nullptr) {
                allocator->free(data_ptr);
            }
        }
        allocated_id_.clear();
    }

private:
    std::unordered_map<std::uint64_t, void*> allocated_id_;
@@ -59,7 +70,20 @@ class DeviceMemoryPool : public noncopyable {
    std::int64_t GetMemoryCapacity(const torch::Device& device);

    DeviceMemoryPool();
    virtual ~DeviceMemoryPool() = default;
    virtual ~DeviceMemoryPool()
    {
        auto allocator = c10::cuda::CUDACachingAllocator::get();
        for (auto& allocated_id : allocated_id_) {
            for (auto& [key, data_ptr] : allocated_id) {
                if (data_ptr != nullptr) {
                    allocator->raw_deallocate(data_ptr);
                }
            }
        }
        allocated_id_.clear();
        free_memory_.clear();
        memory_capacity_.clear();
    }

private:
    std::vector<std::unordered_map<std::uint64_t, void*>> allocated_id_;
16 changes: 15 additions & 1 deletion core/prefetch/archer_prefetch_handle.cpp
@@ -22,6 +22,8 @@ ArcherPrefetchHandle::ArcherPrefetchHandle(const std::string& prefix,
    kArcherTensorHandle = std::make_unique<ArcherTensorHandle>(prefix);
    kTopologyHandle = std::make_unique<ArcherTopologyHandle>();
    kTaskPool = std::make_unique<ArcherTaskPool>();
    kDeviceMemoryPool = std::make_unique<DeviceMemoryPool>();
    kHostMemoryPool = std::make_unique<HostMemoryPool>();
    kDeviceMemoryPool->SetMemoryRatio(device_memory_ratio);
    ARCHER_LOG_DEBUG("Free Device Memory ", kDeviceMemoryPool->GetFreeMemory(CUDA_DEVICE(0)));

@@ -40,7 +42,16 @@
            cudaDeviceCanAccessPeer(&can_access, i, j);
            if (can_access == 1) {
                cudaSetDevice(i);
                cudaDeviceEnablePeerAccess(j, 0);
                cudaError_t status = cudaDeviceEnablePeerAccess(j, 0);
                if (status == cudaErrorPeerAccessAlreadyEnabled) {
                    ARCHER_LOG_INFO("Peer access already enabled between device ", i, j);
                    cudaGetLastError();  // clear error
                } else if (status != cudaSuccess) {
                    ARCHER_LOG_ERROR("Failed to enable peer access between device ", i, j);
                } else {
                    ARCHER_LOG_INFO("Enabled peer access between device ", i, j);
                }

            }
        }
    }
@@ -54,6 +65,9 @@ ArcherPrefetchHandle::~ArcherPrefetchHandle()
    // served as a global manager for order of destruction
    kTaskPool.reset();
    kArcherTensorHandle.reset();
    kTopologyHandle.reset();
    kDeviceMemoryPool.reset();
    kHostMemoryPool.reset();
}

void ArcherPrefetchHandle::AcquireTensor(std::uint64_t& request_id, torch::Tensor& buffer)
59 changes: 39 additions & 20 deletions moe_infinity/runtime/model_offload.py
@@ -16,6 +16,7 @@

from tqdm import tqdm

import moe_infinity.modeling_grok
from moe_infinity.ops.op_builder.prefetch import PrefetchBuilder
from moe_infinity.models import (
    SyncSwitchTransformersSparseMLP,
@@ -267,11 +268,15 @@ def cast_classifier_decorator(orig_cast_classifier: Callable) -> Callable:
        @functools.wraps(orig_cast_classifier)
        def archer_cast_classifier(cls, *args, **kwargs):
            orig_data_ptr = cls.classifier.weight.data.data_ptr()
            self.offload_set.remove(cls.classifier.weight.data.data_ptr())
            orig_cast_classifier(cls, *args, **kwargs)
            new_data_ptr = cls.classifier.weight.data.data_ptr()
            self.offload_set.add(cls.classifier.weight.data.data_ptr())
            self.archer_engine.update_tensor_map(orig_data_ptr, new_data_ptr)
            if orig_data_ptr in self.offload_set:
                self.offload_set.remove(cls.classifier.weight.data.data_ptr())
                orig_cast_classifier(cls, *args, **kwargs)
                new_data_ptr = cls.classifier.weight.data.data_ptr()
                self.offload_set.add(cls.classifier.weight.data.data_ptr())
                self.archer_engine.update_tensor_map(orig_data_ptr, new_data_ptr)
            else:
                orig_cast_classifier(cls, *args, **kwargs)
                self.offload_set.add(cls.classifier.weight.data.data_ptr())

        return archer_cast_classifier

@@ -326,7 +331,8 @@ def archer_cast_classifier(cls, *args, **kwargs):
            if hasattr(module, "reset_parameters"):
                module._old_reset_parameters = module.reset_parameters
                module.reset_parameters = do_nothing_decorator(module.reset_parameters)


        transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._old_cast_classifier = transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier
        transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier = cast_classifier_decorator(transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier)

        transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp = (
@@ -588,17 +594,16 @@ def archer_from_pretrained(cls, *args, **kwargs):

        return self

    # clean up initialization hooks
    def __exit__(self, exc_type, exc_value, traceback):
        # self.cls._load_pretrained_model = self.cls._old_load_pretrained_model
        # self.cls.from_pretrained = self.cls._old_from_pretrained
        self.cls.__init__ = self.cls._old_init
        self.cls.from_pretrained = self.cls._old_from_pretrained
        torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply
        torch.index_select = torch._old_index_select
        torch.Tensor.index_select = torch.Tensor._old_index_select

        self.cls.post_init = self.cls._old_post_init
        PreTrainedModel.post_init = PreTrainedModel._old_post_init
        # self.cls.config_class.from_pretrained = (
        # self.cls.config_class._old_from_pretrained)
        # transformers.modeling_utils.load_state_dict = (
        # transformers.modeling_utils.old_load_state_dict)
        torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply

        for name, module in torch.nn.modules.__dict__.items():
            if not isinstance(module, type):
@@ -619,13 +624,6 @@ def __exit__(self, exc_type, exc_value, traceback):
            if hasattr(module, "reset_parameters"):
                module.reset_parameters = module._old_reset_parameters

        transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP = (
            transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp
        )
        transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeSparseMLP = (
            transformers.models.nllb_moe.modeling_nllb_moe._old_sparse_mlp
        )

    def get_topology(self, model):
        name_lst = []
        ret_dict = {}
@@ -959,3 +957,24 @@ def _post_forward_module_hook(module, input, output):
        self.forward_hooks.append(
            module.register_forward_hook(_post_forward_module_hook)
        )

    # clean runtime hooks
    def clean_up(self):
        transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier = transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._old_cast_classifier
        transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP = (
            transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp
        )

        transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeSparseMLP = (
            transformers.models.nllb_moe.modeling_nllb_moe._old_sparse_mlp
        )

        transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock = (
            transformers.models.mixtral.modeling_mixtral._old_sparse_mlp
        )

        moe_infinity.modeling_grok.modeling_grok1.MoeBlock = (
            moe_infinity.modeling_grok.modeling_grok1._old_sparse_mlp
        )


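The new clean_up() method, together with the reworked __exit__, separates initialization hooks from runtime hooks: patches needed only while the model is loaded are undone when the context exits, while patches that must stay active during generation are undone later by clean_up(), so the next task can re-enter initialization without hitting stale monkey-patches. Below is a self-contained sketch of that pattern using placeholder patches; PatchingContext is a made-up name for illustration, not the real model_offload.py class.

import torch

class PatchingContext:
    """Generic illustration only; not moe_infinity code."""

    def __enter__(self):
        # initialization hook: needed only while the model is being constructed
        self._old_apply = torch.nn.Module.apply
        torch.nn.Module.apply = lambda module, fn: module  # no-op replacement
        # runtime hook: must stay installed after __exit__, until clean_up()
        self._old_relu = torch.nn.functional.relu
        torch.nn.functional.relu = lambda x, inplace=False: x.clamp_min(0.0)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # undo only the initialization-time patch, like the reworked __exit__ above
        torch.nn.Module.apply = self._old_apply

    def clean_up(self):
        # undo the runtime patch, like clean_up() above, before the next task
        torch.nn.functional.relu = self._old_relu

for _ in range(2):  # two consecutive tasks in one process
    ctx = PatchingContext()
    with ctx:
        model = torch.nn.Linear(4, 4)  # "loading" happens under the init-time patch
    _ = torch.nn.functional.relu(torch.randn(2, 4))  # runtime patch still active here
    ctx.clean_up()  # restore everything so the next iteration starts clean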