diff --git a/core/memory/memory_pool.cpp b/core/memory/memory_pool.cpp index 3c57db3..a03f4d4 100644 --- a/core/memory/memory_pool.cpp +++ b/core/memory/memory_pool.cpp @@ -17,8 +17,8 @@ #include #include -std::unique_ptr kHostMemoryPool = std::make_unique(); -std::unique_ptr kDeviceMemoryPool = std::make_unique(); +std::unique_ptr kHostMemoryPool(nullptr); +std::unique_ptr kDeviceMemoryPool(nullptr); std::size_t GetTotalSystemMemory() { diff --git a/core/memory/memory_pool.h b/core/memory/memory_pool.h index 972978f..02424c3 100644 --- a/core/memory/memory_pool.h +++ b/core/memory/memory_pool.h @@ -7,7 +7,9 @@ #include "common/pytorch.h" #include "utils/noncopyable.h" -#include +#include "utils/archer_logger.h" +#include "host_caching_allocator.h" +#include #include #include #include @@ -35,7 +37,16 @@ class HostMemoryPool : public noncopyable { std::int64_t GetMemoryCapacity(); HostMemoryPool(); - virtual ~HostMemoryPool() = default; + virtual ~HostMemoryPool() + { + auto allocator = c10::HostCachingAllocator::get(); + for (auto& [key, data_ptr] : allocated_id_) { + if (data_ptr != nullptr) { + allocator->free(data_ptr); + } + } + allocated_id_.clear(); + } private: std::unordered_map allocated_id_; @@ -59,7 +70,20 @@ class DeviceMemoryPool : public noncopyable { std::int64_t GetMemoryCapacity(const torch::Device& device); DeviceMemoryPool(); - virtual ~DeviceMemoryPool() = default; + virtual ~DeviceMemoryPool() + { + auto allocator = c10::cuda::CUDACachingAllocator::get(); + for(auto &allocated_id : allocated_id_){ + for (auto& [key, data_ptr] : allocated_id) { + if (data_ptr != nullptr) { + allocator->raw_deallocate(data_ptr); + } + } + } + allocated_id_.clear(); + free_memory_.clear(); + memory_capacity_.clear(); + } private: std::vector> allocated_id_; diff --git a/core/prefetch/archer_prefetch_handle.cpp b/core/prefetch/archer_prefetch_handle.cpp index a107c50..5960b10 100644 --- a/core/prefetch/archer_prefetch_handle.cpp +++ b/core/prefetch/archer_prefetch_handle.cpp @@ -22,6 +22,8 @@ ArcherPrefetchHandle::ArcherPrefetchHandle(const std::string& prefix, kArcherTensorHandle = std::make_unique(prefix); kTopologyHandle = std::make_unique(); kTaskPool = std::make_unique(); + kDeviceMemoryPool = std::make_unique(); + kHostMemoryPool = std::make_unique(); kDeviceMemoryPool->SetMemoryRatio(device_memory_ratio); ARCHER_LOG_DEBUG("Free Device Memory ", kDeviceMemoryPool->GetFreeMemory(CUDA_DEVICE(0))); @@ -40,7 +42,16 @@ ArcherPrefetchHandle::ArcherPrefetchHandle(const std::string& prefix, cudaDeviceCanAccessPeer(&can_access, i, j); if (can_access == 1) { cudaSetDevice(i); - cudaDeviceEnablePeerAccess(j, 0); + cudaError_t status = cudaDeviceEnablePeerAccess(j, 0); + if (status == cudaErrorPeerAccessAlreadyEnabled){ + ARCHER_LOG_INFO("Peer access already enabled between device ", i, j); + cudaGetLastError(); // clear error + } else if (status != cudaSuccess) { + ARCHER_LOG_ERROR("Failed to enable peer access between device ", i, j); + } else { + ARCHER_LOG_INFO("Enabled peer access between device ", i, j); + } + } } } @@ -54,6 +65,9 @@ ArcherPrefetchHandle::~ArcherPrefetchHandle() // served as a global manager for order of destruction kTaskPool.reset(); kArcherTensorHandle.reset(); + kTopologyHandle.reset(); + kDeviceMemoryPool.reset(); + kHostMemoryPool.reset(); } void ArcherPrefetchHandle::AcquireTensor(std::uint64_t& request_id, torch::Tensor& buffer) diff --git a/moe_infinity/runtime/model_offload.py b/moe_infinity/runtime/model_offload.py index f9b2ffc..0cd1648 100644 --- a/moe_infinity/runtime/model_offload.py +++ b/moe_infinity/runtime/model_offload.py @@ -16,6 +16,7 @@ from tqdm import tqdm +import moe_infinity.modeling_grok from moe_infinity.ops.op_builder.prefetch import PrefetchBuilder from moe_infinity.models import ( SyncSwitchTransformersSparseMLP, @@ -267,11 +268,15 @@ def cast_classifier_decorator(orig_cast_classifier: Callable) -> Callable: @functools.wraps(orig_cast_classifier) def archer_cast_classifier(cls, *args, **kwargs): orig_data_ptr = cls.classifier.weight.data.data_ptr() - self.offload_set.remove(cls.classifier.weight.data.data_ptr()) - orig_cast_classifier(cls, *args, **kwargs) - new_data_ptr = cls.classifier.weight.data.data_ptr() - self.offload_set.add(cls.classifier.weight.data.data_ptr()) - self.archer_engine.update_tensor_map(orig_data_ptr, new_data_ptr) + if orig_data_ptr in self.offload_set: + self.offload_set.remove(cls.classifier.weight.data.data_ptr()) + orig_cast_classifier(cls, *args, **kwargs) + new_data_ptr = cls.classifier.weight.data.data_ptr() + self.offload_set.add(cls.classifier.weight.data.data_ptr()) + self.archer_engine.update_tensor_map(orig_data_ptr, new_data_ptr) + else: + orig_cast_classifier(cls, *args, **kwargs) + self.offload_set.add(cls.classifier.weight.data.data_ptr()) return archer_cast_classifier @@ -326,7 +331,8 @@ def archer_cast_classifier(cls, *args, **kwargs): if hasattr(module, "reset_parameters"): module._old_reset_parameters = module.reset_parameters module.reset_parameters = do_nothing_decorator(module.reset_parameters) - + + transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._old_cast_classifier = transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier = cast_classifier_decorator(transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier) transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp = ( @@ -588,17 +594,16 @@ def archer_from_pretrained(cls, *args, **kwargs): return self + # clean up initialization hooks def __exit__(self, exc_type, exc_value, traceback): - # self.cls._load_pretrained_model = self.cls._old_load_pretrained_model - # self.cls.from_pretrained = self.cls._old_from_pretrained self.cls.__init__ = self.cls._old_init + self.cls.from_pretrained = self.cls._old_from_pretrained + torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply + torch.index_select = torch._old_index_select + torch.Tensor.index_select = torch.Tensor._old_index_select + self.cls.post_init = self.cls._old_post_init PreTrainedModel.post_init = PreTrainedModel._old_post_init - # self.cls.config_class.from_pretrained = ( - # self.cls.config_class._old_from_pretrained) - # transformers.modeling_utils.load_state_dict = ( - # transformers.modeling_utils.old_load_state_dict) - torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply for name, module in torch.nn.modules.__dict__.items(): if not isinstance(module, type): @@ -619,13 +624,6 @@ def __exit__(self, exc_type, exc_value, traceback): if hasattr(module, "reset_parameters"): module.reset_parameters = module._old_reset_parameters - transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP = ( - transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp - ) - transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeSparseMLP = ( - transformers.models.nllb_moe.modeling_nllb_moe._old_sparse_mlp - ) - def get_topology(self, model): name_lst = [] ret_dict = {} @@ -959,3 +957,24 @@ def _post_forward_module_hook(module, input, output): self.forward_hooks.append( module.register_forward_hook(_post_forward_module_hook) ) + + # clean runtime hooks + def clean_up(self): + transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier = transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._old_cast_classifier + transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP = ( + transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp + ) + + transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeSparseMLP = ( + transformers.models.nllb_moe.modeling_nllb_moe._old_sparse_mlp + ) + + transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock = ( + transformers.models.mixtral.modeling_mixtral._old_sparse_mlp + ) + + moe_infinity.modeling_grok.modeling_grok1.MoeBlock = ( + moe_infinity.modeling_grok.modeling_grok1._old_sparse_mlp + ) + + \ No newline at end of file