From 350f0dd6501b6ca989e605b356deb5d1575e8c2e Mon Sep 17 00:00:00 2001 From: Leyang Xue Date: Thu, 15 Aug 2024 20:51:42 +0100 Subject: [PATCH 1/4] add override QuantLinear (#29) Co-authored-by: xly --- moe_infinity/runtime/model_offload.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/moe_infinity/runtime/model_offload.py b/moe_infinity/runtime/model_offload.py index 6e8102b..1750b6e 100644 --- a/moe_infinity/runtime/model_offload.py +++ b/moe_infinity/runtime/model_offload.py @@ -9,6 +9,8 @@ import math import torch.distributed as dist from torch.distributed import rpc +from auto_gptq.nn_modules.qlinear.qlinear_cuda import QuantLinear +from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as QuantLinearOld import torch import functools @@ -280,6 +282,13 @@ def archer_cast_classifier(cls, *args, **kwargs): self.offload_set.add(cls.classifier.weight.data.data_ptr()) return archer_cast_classifier + + + # GPTQ Override + QuantLinear._old_init = QuantLinear.__init__ + QuantLinear.__init__ = param_init_decorator(QuantLinear.__init__) + QuantLinearOld._old_init = QuantLinearOld.__init__ + QuantLinearOld.__init__ = param_init_decorator(QuantLinearOld.__init__) self.cls._old_init = self.cls.__init__ self.cls.__init__ = init_decorator(self.cls._old_init) @@ -605,6 +614,11 @@ def archer_from_pretrained(cls, *args, **kwargs): # clean up initialization hooks def __exit__(self, exc_type, exc_value, traceback): + + # GPTQ Override + QuantLinear.__init__ = QuantLinear._old_init + QuantLinearOld.__init__ = QuantLinearOld._old_init + self.cls.__init__ = self.cls._old_init self.cls.from_pretrained = self.cls._old_from_pretrained torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply From 2d21aba6923653e2fb8653728cc5d87e5c0db875 Mon Sep 17 00:00:00 2001 From: xly Date: Tue, 26 Nov 2024 15:49:22 +0000 Subject: [PATCH 2/4] use torch streampool --- core/memory/stream_pool.cpp | 4 +- core/memory/stream_pool.h | 16 +- core/model/model_topology.cpp | 16 +- core/model/model_topology.h | 2 +- core/parallel/expert_dispatcher.cpp | 2 +- core/prefetch/task_scheduler.cpp | 4 +- core/prefetch/task_scheduler.h | 1 + core/trace/archer_tensor_tracer.cpp | 140 ------ core/trace/archer_tensor_tracer.h | 30 -- core/trace/model_topology.cpp | 699 ---------------------------- core/trace/model_topology.h | 203 -------- 11 files changed, 24 insertions(+), 1093 deletions(-) delete mode 100644 core/trace/archer_tensor_tracer.cpp delete mode 100644 core/trace/archer_tensor_tracer.h delete mode 100644 core/trace/model_topology.cpp delete mode 100644 core/trace/model_topology.h diff --git a/core/memory/stream_pool.cpp b/core/memory/stream_pool.cpp index 4a8ef15..843b2a5 100644 --- a/core/memory/stream_pool.cpp +++ b/core/memory/stream_pool.cpp @@ -6,5 +6,5 @@ #include "stream_pool.h" // Stream0 is used for H2D, Stream1 is used for Kernel, Stream2 is used for D2H -// CUDAStreamPool* kCUDAStreamPool = CUDAStreamPool::GetInstance(); -std::unique_ptr kCUDAStreamPool = std::make_unique(); +// TorchStreamPool* kTorchStreamPool = TorchStreamPool::GetInstance(); +std::unique_ptr kTorchStreamPool = std::make_unique(); diff --git a/core/memory/stream_pool.h b/core/memory/stream_pool.h index e778586..d3af584 100644 --- a/core/memory/stream_pool.h +++ b/core/memory/stream_pool.h @@ -10,14 +10,14 @@ #include "utils/cuda_utils.h" #include "utils/noncopyable.h" -class CUDAStreamPool : public noncopyable { +class TorchStreamPool : public noncopyable { public: std::vector& operator()(const 
int device_id) { return cuda_streams_[device_id]; } - CUDAStreamPool() + TorchStreamPool() { int num_devices = GetDeviceCount(); for (int i = 0; i < num_devices; ++i) { @@ -28,14 +28,14 @@ class CUDAStreamPool : public noncopyable { cuda_streams_.push_back(std::move(streams)); } } - virtual ~CUDAStreamPool() = default; + virtual ~TorchStreamPool() = default; private: std::vector> cuda_streams_; }; -extern std::unique_ptr kCUDAStreamPool; -#define CUDA_STREAM_VIEW(device_id, stream_id) (*kCUDAStreamPool)(device_id)[stream_id] -#define CUDA_STREAM_H2D_VIEW(device_id) CUDA_STREAM_VIEW(device_id, 0) -#define CUDA_STREAM_D2H_VIEW(device_id) CUDA_STREAM_VIEW(device_id, 1) -#define CUDA_STREAM_COMPUTE_VIEW(device_id) CUDA_STREAM_VIEW(device_id, 2) +extern std::unique_ptr kTorchStreamPool; +#define TORCH_STREAM_VIEW(device_id, stream_id) (*kTorchStreamPool)(device_id)[stream_id] +#define TORCH_STREAM_H2D_VIEW(device_id) TORCH_STREAM_VIEW(device_id, 0) +#define TORCH_STREAM_D2H_VIEW(device_id) TORCH_STREAM_VIEW(device_id, 1) +#define TORCH_STREAM_COMPUTE_VIEW(device_id) TORCH_STREAM_VIEW(device_id, 2) diff --git a/core/model/model_topology.cpp b/core/model/model_topology.cpp index 69e32d7..a089249 100644 --- a/core/model/model_topology.cpp +++ b/core/model/model_topology.cpp @@ -22,7 +22,7 @@ #include "prefetch/task_scheduler.h" #include "utils/archer_logger.h" -cudaStream_t kCudaStreamH2D = NULL; +// cudaStream_t kCudaStreamH2D = NULL; std::unique_ptr kTopologyHandle = nullptr; const std::string Node::str() noexcept @@ -73,13 +73,13 @@ void Node::SetDevice(const torch::Device& target_device, return; } - if (kCudaStreamH2D == NULL) { - auto cudaError = cudaStreamCreateWithFlags(&kCudaStreamH2D, cudaStreamNonBlocking); - if (cudaError != cudaSuccess) { - ARCHER_LOG_ERROR("cudaStreamCreate failed: {}", cudaGetErrorString(cudaError)); - exit(-1); - } - } + // if (kCudaStreamH2D == NULL) { + // auto cudaError = cudaStreamCreateWithFlags(&kCudaStreamH2D, cudaStreamNonBlocking); + // if (cudaError != cudaSuccess) { + // ARCHER_LOG_ERROR("cudaStreamCreate failed: {}", cudaGetErrorString(cudaError)); + // exit(-1); + // } + // } if (target_device == DISK_DEVICE) { SetModuleDisk(tensor_ids); diff --git a/core/model/model_topology.h b/core/model/model_topology.h index 91dda70..7478323 100644 --- a/core/model/model_topology.h +++ b/core/model/model_topology.h @@ -24,7 +24,7 @@ enum NodeState { NODE_STATE_VISITED = 0x4, }; -extern cudaStream_t kCudaStreamH2D; +// extern cudaStream_t kCudaStreamH2D; struct Node { std::vector tensor_ids; diff --git a/core/parallel/expert_dispatcher.cpp b/core/parallel/expert_dispatcher.cpp index 6996edc..93856d9 100644 --- a/core/parallel/expert_dispatcher.cpp +++ b/core/parallel/expert_dispatcher.cpp @@ -325,7 +325,7 @@ void ExpertDispatcher::GPUExecFunc(int gpu_id) at::InferenceMode infer_guard(true); c10::cuda::CUDAStream stream = - c10::cuda::getStreamFromExternal(fetch_streams_[gpu_id], gpu_id); + c10::cuda::getStreamFromExternal(exec_streams_[gpu_id], gpu_id); { c10::cuda::CUDAStreamGuard guard(stream); diff --git a/core/prefetch/task_scheduler.cpp b/core/prefetch/task_scheduler.cpp index 1cebafe..9f48fb4 100644 --- a/core/prefetch/task_scheduler.cpp +++ b/core/prefetch/task_scheduler.cpp @@ -9,6 +9,7 @@ #include "common/time.h" #include "memory/memory_pool.h" +#include "memory/stream_pool.h" #include "task_scheduler.h" #include "task_thread.h" #include "utils/archer_logger.h" @@ -518,7 +519,8 @@ void ArcherTaskPool::SetNodeDevice(const TaskPtr& task) auto start_time = 
MCIROSECONDS_SINCE_EPOCH; // node->SetDevice(task->dst_device); - node->SetDevice(task->dst_device, task->on_demand); + task->stream = TORCH_STREAM_H2D_VIEW(task->dst_device.index()).stream(); + node->SetDevice(task->dst_device, task->on_demand, task->stream); auto end_time = MCIROSECONDS_SINCE_EPOCH; ARCHER_LOG_DEBUG( "SetNodeDevice: task: {}, emplace time {} us", task->DebugString(), end_time - start_time); diff --git a/core/prefetch/task_scheduler.h b/core/prefetch/task_scheduler.h index 2016f27..da978c0 100644 --- a/core/prefetch/task_scheduler.h +++ b/core/prefetch/task_scheduler.h @@ -31,6 +31,7 @@ struct Task { std::uint64_t request_id; torch::Device src_device = DISK_DEVICE; torch::Device dst_device = DISK_DEVICE; + cudaStream_t stream = nullptr; bool remove_layer = false; diff --git a/core/trace/archer_tensor_tracer.cpp b/core/trace/archer_tensor_tracer.cpp deleted file mode 100644 index dd0679b..0000000 --- a/core/trace/archer_tensor_tracer.cpp +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) TorchMoE. -// SPDX-License-Identifier: Apache-2.0 - -// TorchMoE Team - -#include "archer_tensor_tracer.h" - -#include -#include - -#include "utils/archer_logger.h" - -void ArcherTensorTracer::AddTrace(const std::uint32_t layer_id, - const std::vector buffer) -{ - assert(request_id_ != UINT64_MAX); - if (traces_.find(request_id_) == traces_.end()) { - traces_[request_id_] = std::vector>>(); - } - traces_[request_id_].push_back({layer_id, buffer}); - - max_layer_id_ = std::max(max_layer_id_, layer_id); -} - -void ArcherTensorTracer::ClearRequestID() -{ - if (traces_.size() > 1000) { traces_.erase(request_id_); } - request_id_ = UINT64_MAX; -} - -std::vector ArcherTensorTracer::GetCandidates(std::uint32_t layer_id) -{ - if (traces_.size() == 1) { return {}; } - - std::vector< - std::pair>>, double>> - candidates; - - for (auto& req_item : traces_[request_id_]) { - auto req_layer = req_item.first; - auto req_tensors = req_item.second; - - for (auto& hist_item : traces_) { - auto hist_req_id = hist_item.first; - auto value = hist_item.second; - - if (hist_req_id == request_id_) continue; - if (layer_id < value[0].first) continue; - - for (auto& value_item : value) { - auto hist_layer = value_item.first; - auto hist_tensors = value_item.second; - - if (req_layer != hist_layer) continue; - - std::vector overlap; - std::set_intersection(req_tensors.begin(), - req_tensors.end(), - hist_tensors.begin(), - hist_tensors.end(), - std::back_inserter(overlap)); - - std::vector total; - std::set_union(req_tensors.begin(), - req_tensors.end(), - hist_tensors.begin(), - hist_tensors.end(), - std::back_inserter(total)); - - double prob = - static_cast(overlap.size()) / total.size() / (layer_id - req_layer + 1); - - std::vector>> candidate_value; - for (auto& pair : value) { - if (pair.first > layer_id) { candidate_value.push_back(pair); } - } - - if (!candidate_value.empty()) { candidates.push_back({candidate_value, prob}); } - break; - } - } - } - - if (candidates.empty()) { return {}; } - - std::unordered_map> - tensor_probs; // , prob - for (auto& item : candidates) { - auto layer_tensors = item.first; - auto prob = item.second; - - for (auto& layer_item : layer_tensors) { - auto layer = layer_item.first; - auto tensors = layer_item.second; - if (tensor_probs.find(layer) == tensor_probs.end()) { - tensor_probs[layer] = std::unordered_map(); - } - auto& layer_tensor_probs = tensor_probs[layer]; - for (std::uint32_t tensor : tensors) { - if (layer_tensor_probs.find(tensor) == layer_tensor_probs.end()) { - 
layer_tensor_probs[tensor] = 0; - } - layer_tensor_probs[tensor] += prob; - } - } - } - - // find top 10 tensors for each layer id - std::vector tensor_ids; - for (auto& item : tensor_probs) { - // auto layer = item.first; - auto& layer_tensor_probs = item.second; - - // if (layer < layer_id + 2) continue; // skip the layers that are too close to the current - // layer - - std::vector> layer_tensor_probs_vec( - layer_tensor_probs.begin(), layer_tensor_probs.end()); - - layer_tensor_probs_vec.erase( - std::remove_if( - layer_tensor_probs_vec.begin(), - layer_tensor_probs_vec.end(), - [](const std::pair& pair) { return pair.second < 0.01; }), - layer_tensor_probs_vec.end()); - - if (layer_tensor_probs_vec.empty()) { continue; } - - std::sort(layer_tensor_probs_vec.begin(), - layer_tensor_probs_vec.end(), - [](const std::pair& a, - const std::pair& b) { return a.second > b.second; }); - - std::size_t width = 20; - for (std::uint32_t i = 0; i < std::min(layer_tensor_probs_vec.size(), width); ++i) { - tensor_ids.push_back(layer_tensor_probs_vec[i].first); - } - } - return tensor_ids; -} diff --git a/core/trace/archer_tensor_tracer.h b/core/trace/archer_tensor_tracer.h deleted file mode 100644 index c91ba94..0000000 --- a/core/trace/archer_tensor_tracer.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) TorchMoE. -// SPDX-License-Identifier: Apache-2.0 - -// TorchMoE Team - -#pragma once - -#include -#include -#include -#include - -class ArcherTensorTracer { -public: - // ArcherTensorTracer(const std::string& prefix); - - void SetRequestID(const std::uint64_t& request_id) { request_id_ = request_id; } - std::uint64_t GetRequestID() { return request_id_; } - void ClearRequestID(); - - void AddTrace(const std::uint32_t layer_id, const std::vector buffer); - std::vector GetCandidates(std::uint32_t layer_id); - -private: - std::uint64_t request_id_; - std::unordered_map>>> - traces_; - std::uint32_t max_layer_id_ = 0; -}; diff --git a/core/trace/model_topology.cpp b/core/trace/model_topology.cpp deleted file mode 100644 index 54d33e2..0000000 --- a/core/trace/model_topology.cpp +++ /dev/null @@ -1,699 +0,0 @@ -// Copyright (c) TorchMoE. 
-// SPDX-License-Identifier: Apache-2.0 - -// TorchMoE Team - -#include "model_topology.h" - -#include -#include -#include -#include -#include -#include -#include "aio/archer_prio_aio_handle.h" -#include "aio/archer_tensor_handle.h" -#include "aio/archer_tensor_index.h" -#include "common/time.h" -#include "common/types.h" -#include "memory/memory_pool.h" -#include "memory/stream_pool.h" -#include "parallel/expert_dispatcher.h" -#include "prefetch/task_scheduler.h" -#include "utils/archer_logger.h" - -cudaStream_t kCudaStreamH2D = NULL; -std::unique_ptr kTopologyHandle = nullptr; - -const std::string Node::str() noexcept -{ - // write same string using c style sprintf - std::stringstream ss; - for (auto& tensor_id : tensor_ids) { ss << tensor_id << ","; } - - char buffer[1024]; - memset(buffer, 0, 1024); - sprintf(buffer, - "ID[%ld,%lx] (%ldMB) STATE(%d) TENSOR[%s] DEVICE[%s;%s;%s];", - id, - corr_id, - byte_size / MB, - state.load(), - ss.str().c_str(), - device.str().c_str(), - default_device.str().c_str(), - default_host.str().c_str()); - - return std::string(buffer); -} - -Node::Node() - : corr_id(0), - byte_size(0), - last_access_time(MCIROSECONDS_SINCE_EPOCH), - device(DISK_DEVICE), - default_device(DEFAULT_CUDA_DEVICE) -{ -} - -void Node::SetDevice(const torch::Device& target_device, - bool on_demand, - cudaStream_t stream) noexcept -{ - ARCHER_LOG_DEBUG("SetDevice: ", str(), " to ", target_device.str()); - if (device == target_device) { - ARCHER_LOG_DEBUG("SetDevice: " + str() + " to " + target_device.str() + - " but device is the same"); - return; - } - - if (device.type() == target_device.type()) { - ARCHER_LOG_WARN("SetDevice: " + str() + " to " + target_device.str() + - " but device type is the same"); - return; - } - - if (kCudaStreamH2D == NULL) { - auto cudaError = cudaStreamCreateWithFlags(&kCudaStreamH2D, cudaStreamNonBlocking); - if (cudaError != cudaSuccess) { - ARCHER_LOG_ERROR("cudaStreamCreate failed: ", cudaGetErrorString(cudaError)); - exit(-1); - } - } - - if (target_device == DISK_DEVICE) { - SetModuleDisk(tensor_ids); - if (host_memory_ptr != nullptr) { - kHostMemoryPool->FreeMemory(id, host_memory_ptr, byte_size, CPU_DEVICE); - host_memory_ptr = nullptr; - } - if (device_memory_ptr != nullptr) { - kDeviceMemoryPool->FreeMemory(id, device_memory_ptr, byte_size, device); - device_memory_ptr = nullptr; - } - } else { - // both are null, which means the node is not initialized - if (host_memory_ptr == nullptr && device_memory_ptr == nullptr) { - // int numa_id = - // default_device.index() / 4; // TODO: 8 gpus, 2 numa nodes, so 4 gpus per numa - host_memory_ptr = kHostMemoryPool->AllocateMemory(id, byte_size, CPU_DEVICE); - assert(host_memory_ptr != nullptr); - - auto start_time = MCIROSECONDS_SINCE_EPOCH; - SetModuleMemoryFromDisk(tensor_ids, host_memory_ptr, on_demand); - auto end_time = MCIROSECONDS_SINCE_EPOCH; - ARCHER_LOG_DEBUG("SetModuleMemoryFromDisk time:", end_time - start_time, " us"); - } - - if (target_device.is_cuda()) { - // ARCHER_LOG_DEBUG("Allocate GPU Memory for node {}", this->id); - device_memory_ptr = kDeviceMemoryPool->AllocateMemory(id, byte_size, target_device); - // ARCHER_LOG_DEBUG("Allocate GPU Memory for node {} done", this->id); - assert(device_memory_ptr != nullptr); - assert(host_memory_ptr != nullptr); - - auto start_time = MCIROSECONDS_SINCE_EPOCH; - if (stream == nullptr) { - cudaMemcpy(device_memory_ptr, host_memory_ptr, byte_size, cudaMemcpyHostToDevice); - } else { - cudaMemcpyAsync( - device_memory_ptr, host_memory_ptr, 
byte_size, cudaMemcpyHostToDevice, stream); - cudaStreamSynchronize(stream); - } - SetModuleCudaMemoryFromCPU(tensor_ids, device_memory_ptr, target_device); - auto end_time = MCIROSECONDS_SINCE_EPOCH; - ARCHER_LOG_DEBUG("SetModuleCudaMemoryFromCPU time: {} us", end_time - start_time); - } - - if (target_device.is_cpu() && device.is_cuda()) { - assert(host_memory_ptr != nullptr); - auto start_time = MCIROSECONDS_SINCE_EPOCH; - SetModuleMemoryFromCuda(tensor_ids, host_memory_ptr); - kDeviceMemoryPool->FreeMemory(id, device_memory_ptr, byte_size, device); - device_memory_ptr = nullptr; - auto end_time = MCIROSECONDS_SINCE_EPOCH; - ARCHER_LOG_DEBUG("SetModuleMemoryFromCuda time: {} us", end_time - start_time); - } - } - device = target_device; -} - -ArcherTopologyHandle::ArcherTopologyHandle() {} - -NodePtrList ArcherTopologyHandle::GetLFUNodes(const torch::Device& device) -{ - NodePtrList nodes; - std::lock_guard lock(mutex_); - for (auto node_body : lfu_nodes_) { - CONTINUE_IF_NULL(node_body); - if (node_body->node->device == device) { nodes.push_back(node_body->node); } - } - return nodes; -} - -NodePtrList ArcherTopologyHandle::GetDenseNodes() -{ - NodePtrList nodes; - for (auto stage : pipeline_.stages) { - if (stage->is_sparse) { continue; } - for (auto node_body : stage->nodes) { nodes.push_back(node_body->node); } - } - return nodes; -} -NodePtrList ArcherTopologyHandle::GetSparseNodes() -{ - NodePtrList nodes; - for (auto stage : pipeline_.stages) { - if (!stage->is_sparse) { continue; } - for (auto node_body : stage->nodes) { nodes.push_back(node_body->node); } - } - return nodes; -} - -NodePtrList ArcherTopologyHandle::GetDenseNodes(const NodePtr& node, const std::size_t& k) -{ - NodePtrList nodes; - - std::size_t low_corr_id = node->corr_id & 0xFFFFFFFF; // stage id - std::size_t high_corr_id = node->corr_id >> 32; // node id - bool is_last_node = (0xFFFFFFFF == high_corr_id); - if (is_last_node) { - high_corr_id = 0; // reset to 0 avoid miss use - } - - std::lock_guard lock(mutex_); - - low_corr_id++; - std::size_t count = 0; - while ((low_corr_id < pipeline_.stages.size()) && (count < k)) { - // Due to MoE design, we only process layer by layer - auto stage = pipeline_.stages[low_corr_id]; - low_corr_id++; - if (stage->is_sparse) { continue; } - - nodes.push_back(stage->nodes[0]->node); - count++; - } - return nodes; -} - -NodePtrList ArcherTopologyHandle::GetSparseNodes(const NodePtr& node, const std::size_t& k) -{ - NodePtrList nodes; - - std::size_t low_corr_id = node->corr_id & 0xFFFFFFFF; // stage id - std::size_t high_corr_id = node->corr_id >> 32; // node id - bool is_last_node = (0xFFFFFFFF == high_corr_id); - if (is_last_node) { - high_corr_id = 0; // reset to 0 avoid miss use - } - - std::lock_guard lock(mutex_); - - low_corr_id++; - std::size_t count = 0; - while ((low_corr_id < pipeline_.stages.size()) && (count < k)) { - // Due to MoE design, we only process layer by layer - auto stage = pipeline_.stages[low_corr_id]; - - low_corr_id++; - if (!stage->is_sparse) { continue; } - - nodes.push_back(stage->nodes[0]->node); - count++; - } - return nodes; -} - -std::uint64_t ArcherTopologyHandle::GetLastActivateStage(const HashID& hash_id) -{ - std::lock_guard lock(mutex_); - auto it = last_active_stage_.find(hash_id); - if (it == last_active_stage_.end()) { return 0; } - return it->second; -} - -std::vector> ArcherTopologyHandle::GetNodeVisitCounts() -{ - std::lock_guard lock(mutex_); - std::vector> node_visit_counts; - for (auto& stage : pipeline_.stages) { - for (auto& 
node : stage->nodes) { - node->node->io_state = NODE_STATE_NONE; - std::vector metrics{node->visit_cnt, - node->gpu_visit_cnt, - node->cpu_visit_cnt, - node->hit_cnt, - node->gpu_hit_cnt, - node->cpu_hit_cnt, - node->node->tensor_ids.size(), - node->prefetch_cnt, - node->node->unused_count, - node->node->io_state, - node->is_sparse}; - node_visit_counts.push_back(metrics); - } - } - return node_visit_counts; -} - -std::vector ArcherTopologyHandle::GetChildVisitCounts() -{ - std::lock_guard lock(mutex_); - int num_layers = 0; - int num_experts = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - num_layers += 1; - num_experts = stage->nodes.size(); - } - } - std::vector child_visit_counts((num_layers - 1) * num_experts * num_experts); - int layer_idx = 0; - int parent_idx = 0; - int expert_idx = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - for (auto& node : stage->nodes) { - if (node->children.size() > 0) { - for (auto& count : node->children_visit_cnt) { - child_visit_counts[layer_idx * num_experts * num_experts + - parent_idx * num_experts + expert_idx] = count; - expert_idx++; - } - } - parent_idx++; - expert_idx = 0; - } - layer_idx++; - parent_idx = 0; - } - } - - return child_visit_counts; -} - -void ArcherTopologyHandle::SetNodeVisitCounts(const std::vector& visit_counts) -{ - std::lock_guard lock(mutex_); - std::size_t num_nodes = 0; - std::size_t num_experts = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - num_nodes += stage->nodes.size(); - num_experts = stage->nodes.size(); - } - } - if (visit_counts.size() != num_nodes) { - ARCHER_LOG_ERROR( - "visit_counts size {} not equal to num_nodes {}", visit_counts.size(), num_nodes); - return; - } - - int layer_idx = 0; - int expert_idx = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - for (auto& node : stage->nodes) { - node->visit_cnt = visit_counts[layer_idx * num_experts + expert_idx]; - expert_idx++; - } - layer_idx++; - expert_idx = 0; - } - } - - DisableTrace(); -} -void ArcherTopologyHandle::SetChildVisitCounts(const std::vector& visit_counts) -{ - std::lock_guard lock(mutex_); - std::size_t num_layers = 0; - std::size_t num_experts = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - num_layers += 1; - num_experts = stage->nodes.size(); - } - } - if (visit_counts.size() != (num_layers - 1) * num_experts * num_experts) { - ARCHER_LOG_ERROR( - "visit_counts size {} not equal to num_layers {}", visit_counts.size(), num_layers); - return; - } - - int layer_idx = 0; - int parent_idx = 0; - int expert_idx = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - for (auto& node : stage->nodes) { - if (node->children.size() > 0) { - for (auto& count : node->children_visit_cnt) { - count = visit_counts[layer_idx * num_experts * num_experts + - parent_idx * num_experts + expert_idx]; - expert_idx++; - } - } - parent_idx++; - expert_idx = 0; - } - layer_idx++; - parent_idx = 0; - } - } - - DisableTrace(); -} - -bool ArcherTopologyHandle::IsLastNode(const NodePtr& node) -{ - std::lock_guard lock(mutex_); - auto last_stage_ptr = pipeline_.stages.back(); - auto& nodes = last_stage_ptr->nodes; - for (auto& n : nodes) { - if (n->node == node) { return true; } - } - return false; -} -bool ArcherTopologyHandle::IsFirstNode(const NodePtr& node) -{ - std::lock_guard lock(mutex_); - auto first_stage_ptr = pipeline_.stages.front(); - auto& nodes = first_stage_ptr->nodes; - for (auto& n : nodes) { - if (n->node == 
node) { return true; } - } - return false; -} - -void ArcherTopologyHandle::InitializeTopology( - const std::vector>>>& topology) -{ - std::lock_guard lock(mutex_); - pipeline_.stages.clear(); - std::size_t node_id = 0; - std::size_t layer_id = 0; - std::size_t last_sparse_layer_id = UINT64_MAX; - - size_t num_sparse_layers = 0; - size_t num_experts = 0; - - std::vector all_nodes; - - for (auto& stage : topology) { - auto& stage_tensors = std::get<1>(stage); - auto stage_ptr = std::make_shared(stage_tensors.size() > 1); - - std::size_t expert_id = 0; - for (auto& tensor_ids : stage_tensors) { - auto node_ptr = std::make_shared(); - node_ptr->tensor_ids = tensor_ids; - int64_t byte_size = 0; - for (auto& tensor_id : tensor_ids) { - auto it = kTensorIndex->find(tensor_id); - if (it != kTensorIndex->end()) { - std::int64_t size_aligned = - (it->second.size + kAioAlignment - 1) & ~(kAioAlignment - 1); - byte_size += size_aligned; - } else { - ARCHER_LOG_ERROR("Tensor {} not found in tensor index", tensor_id); - } - } - node_ptr->byte_size = byte_size; - node_ptr->id = node_id; - node_ptr->corr_id = (layer_id & 0xFFFFFFFF) | ((expert_id & 0xFFFFFFFF) << 32); - node_ptr->is_sparse = stage_ptr->is_sparse; - - all_nodes.push_back(node_ptr); - - auto node_body_ptr = std::make_shared(node_ptr); - node_body_ptr->is_sparse = stage_ptr->is_sparse; - - stage_ptr->nodes.push_back(node_body_ptr); - - node_id++; - expert_id++; - } - pipeline_.stages.push_back(stage_ptr); - auto current_layer_id = layer_id; - layer_id++; - - if (stage_ptr->is_sparse) { - if (UINT64_MAX == last_sparse_layer_id) { - last_sparse_layer_id = current_layer_id; - continue; - } - // set node_body_ptr vectors to be the same size as the number of experts - // all counts initialized to 0 - auto last_sparse_stage_ptr = pipeline_.stages[last_sparse_layer_id]; - for (auto& node : last_sparse_stage_ptr->nodes) { - node->children_visit_cnt.resize(stage_ptr->nodes.size(), 0); - node->children = stage_ptr->nodes; - } - last_sparse_layer_id = current_layer_id; - - num_sparse_layers++; - num_experts = stage_ptr->nodes.size(); - } - } - - // set last stage nodes corr_id higher 32 bits to be 0xFFFFFFFF - auto last_stage_ptr = pipeline_.stages.back(); - for (auto& node_body : last_stage_ptr->nodes) { - node_body->node->corr_id = (node_body->node->corr_id & 0xFFFFFFFF) | (UINT64_MAX << 32); - } - - // output every tensor id in node - for (auto& stage : pipeline_.stages) { - for (auto& node : stage->nodes) { - std::stringstream ss; - for (auto& tensor_id : node->node->tensor_ids) { ss << tensor_id << " "; } - // ARCHER_LOG_DEBUG("Node {} tensor ids {}", node->node->id, ss.str()); - lfu_nodes_.push_back(node); - } - } - - ARCHER_LOG_DEBUG("InitializeTopology pipeline_.stages.size() {}", pipeline_.stages.size()); - - // Model placement - auto num_gpu = GetDeviceCount(); - std::vector free_device_mem(num_gpu, 0); - for (int i = 0; i < num_gpu; i++) { - free_device_mem[i] = kDeviceMemoryPool->GetMemoryCapacity(torch::Device(torch::kCUDA, i)); - } - - auto sparse_nodes = GetSparseNodes(); - auto dense_nodes = GetDenseNodes(); - - ARCHER_LOG_DEBUG("InitializeTopology num_gpu {} sparse_nodes.size() {} dense_nodes.size() {}", - num_gpu, - sparse_nodes.size(), - dense_nodes.size()); - - int target_device_id = 0; - int dense_gpu_idx = 0; - int sparse_gpu_idx = 0; - - // Split evently dense nodes only - int num_dense_nodes_per_device = std::ceil(dense_nodes.size() / num_gpu / 2); - // int total_dense_nodes = dense_nodes.size(); - int counter = 0; - for (auto& 
node_ptr : dense_nodes) { - // split dense node evenly among GPUs - node_ptr->default_device = torch::Device(torch::kCUDA, target_device_id); - counter++; - if (counter % num_dense_nodes_per_device == 0) { - target_device_id = (target_device_id + 1) % num_gpu; - } - } - dense_nodes.back()->default_device = torch::Device(torch::kCUDA, num_gpu - 1); - - // split evenly sparse nodes among GPUs - for (auto& node_ptr : sparse_nodes) { - node_ptr->default_device = torch::Device(torch::kCUDA, target_device_id); - target_device_id = (target_device_id + 1) % num_gpu; - } - - ARCHER_LOG_DEBUG("InitializeTopology pipeline_.stages.size() {}", pipeline_.stages.size()); - - for (auto& node_ptr : all_nodes) { - ARCHER_LOG_DEBUG("Node {} {} device {}", - node_ptr->id, - node_ptr->is_sparse, - node_ptr->default_device.str()); - } - - EnableTrace(); -} - -NodePtr ArcherTopologyHandle::GetNodeFromTensorID(const TensorID& tensor_id) -{ - std::lock_guard lock(mutex_); - - auto it = tensor_id_to_node_.find(tensor_id); - if (it != tensor_id_to_node_.end()) { - return it->second; - } else { - // search in pipeline - for (auto& stage : pipeline_.stages) { - for (auto& node_body : stage->nodes) { - for (auto& id : node_body->node->tensor_ids) { - if (id == tensor_id) { - tensor_id_to_node_[tensor_id] = node_body->node; - return node_body->node; - } - } - } - } - } - ARCHER_LOG_ERROR("Tensor {} not found in tensor id to node map", tensor_id); - return nullptr; -} - -NodeBodyPtr ArcherTopologyHandle::GetNodeBodyFromCorrID(const std::uint64_t& correlation_id) -{ - std::lock_guard lock(mutex_); - - std::uint64_t high_corr_id = correlation_id >> 32; // For children in the same level - std::uint64_t low_corr_id = correlation_id & 0xFFFFFFFF; // For model inference pipeline - - bool is_last_node = (0xFFFFFFFF == high_corr_id); - if (is_last_node) { - high_corr_id = 0; // reset to 0 avoid miss use - } - - auto stage = pipeline_.stages[low_corr_id]; - auto node_body = stage->nodes[high_corr_id]; - - return node_body; -} - -std::int64_t ArcherTopologyHandle::GetSparseCacheLimit(const torch::Device& device) -{ - std::int64_t dense_cache_size = 0; - for (auto& stage : pipeline_.stages) { - for (auto& node_body : stage->nodes) { - if (stage->is_sparse) continue; - if (node_body->node->device == device) { - dense_cache_size += node_body->node->byte_size; - } - } - } - - std::int64_t device_size_limit = (device.is_cuda()) - ? kDeviceMemoryPool->GetMemoryCapacity(device) - : kHostMemoryPool->GetMemoryCapacity(); - assert(device_size_limit > dense_cache_size); - std::int64_t sparse_cache_size = device_size_limit - dense_cache_size; - - return sparse_cache_size; -} - -std::tuple ArcherTopologyHandle::GetNumLayersAndExperts() -{ - std::lock_guard lock(mutex_); - int num_layers = 0; - int num_experts = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - num_layers += 1; - num_experts = stage->nodes.size(); - } - } - return std::make_tuple(num_layers, num_experts); -} - -// CPU, GPU -> DISK -// Moves tensors from CPU/GPU to disk. 
-void SetModuleDisk(std::vector& tensor_ids) -{ - // ARCHER_LOG_DEBUG("SetModuleDisk {} tensors", tensor_ids.size()); - for (const auto& tensor_id : tensor_ids) { - // void* old_ptr = kTensorIndex->find(tensor_id)->second.tensor.data_ptr(); - auto it = kTensorIndex->find(tensor_id); - - at::TensorOptions options; - options = options.device(torch::kCPU); - options = options.dtype(it->second.tensor.dtype()); - auto tensor = torch::zeros({1}, options); - it->second.tensor.set_data(tensor); - } -} - -std::mutex kReadMutex; - -// DISK -> CPU -void SetModuleMemoryFromDisk(std::vector& tensor_ids, void* host_ptr, bool on_demand) -{ - std::int64_t param_size = 0; - for (const auto& tensor_id : tensor_ids) { - // void* old_ptr = kTensorIndex->find(tensor_id)->second.tensor.data_ptr(); - kArcherTensorHandle->ReadTensor( - tensor_id, (void*)((char*)host_ptr + param_size), on_demand); - auto it = kTensorIndex->find(tensor_id); - auto options = torch::TensorOptions() - .dtype(it->second.options.dtype()) - .layout(it->second.options.layout()) - .device(torch::kCPU) - .requires_grad(it->second.options.requires_grad()) - .pinned_memory(it->second.options.pinned_memory()); - - ARCHER_LOG_DEBUG("SetModuleMemoryFromDisk tensor {}", it->second.DebugString()); - auto tensor_tmp = torch::from_blob((void*)((char*)host_ptr + param_size), - it->second.shape, - DoNothingDeleter{}, - options); - if (!it->second.tensor.defined()) { it->second.tensor = torch::zeros({1}, options); } - it->second.tensor.set_data(tensor_tmp); - std::int64_t size_aligned = (it->second.size + kAioAlignment - 1) & ~(kAioAlignment - 1); - param_size += size_aligned; - } -} - -// CPU -> GPU -void SetModuleCudaMemoryFromCPU(std::vector& tensor_ids, - void* device_ptr, - const torch::Device& device) -{ - // ARCHER_LOG_DEBUG("SetModuleCudaMemoryFromCPU {} tensors", tensor_ids.size()); - std::int64_t param_size = 0; - for (const auto& tensor_id : tensor_ids) { - auto it = kTensorIndex->find(tensor_id); - ARCHER_LOG_DEBUG( - "SetModuleCudaMemoryFromCPU tensor {} -> {}", it->second.DebugString(), device.str()); - auto tensor_options = torch::TensorOptions() - .dtype(it->second.options.dtype()) - .layout(it->second.options.layout()) - .device(device) - .requires_grad(it->second.options.requires_grad()) - .pinned_memory(false); - it->second.tensor.set_data(torch::from_blob((char*)device_ptr + param_size, - it->second.shape, - DoNothingDeleter{}, - tensor_options)); - std::int64_t size_aligned = (it->second.size + kAioAlignment - 1) & ~(kAioAlignment - 1); - param_size += size_aligned; - } - // ARCHER_LOG_DEBUG("SetModuleCudaMemoryFromCPU {} tensors done", tensor_ids.size()); -} - -// GPU -> CPU -void SetModuleMemoryFromCuda(std::vector& tensor_ids, void* host_ptr) -{ - std::int64_t param_size = 0; - for (const auto& tensor_id : tensor_ids) { - // void* old_ptr = kTensorIndex->find(tensor_id)->second.tensor.data_ptr(); - - auto it = kTensorIndex->find(tensor_id); - ARCHER_LOG_DEBUG("SetModuleMemoryFromCuda tensor ", it->second.DebugString()); - it->second.tensor.set_data(torch::from_blob((char*)host_ptr + param_size, - it->second.shape, - DoNothingDeleter{}, - it->second.options)); - // kArcherTensorHandle->UpdateTensorMap(old_ptr, it->second.tensor.data_ptr()); - std::int64_t size_aligned = (it->second.size + kAioAlignment - 1) & ~(kAioAlignment - 1); - param_size += size_aligned; - } - // ARCHER_LOG_DEBUG("SetModuleMemoryFromCuda {} tensors done", tensor_ids.size()); -} diff --git a/core/trace/model_topology.h b/core/trace/model_topology.h deleted 
file mode 100644 index 2cd417c..0000000 --- a/core/trace/model_topology.h +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) TorchMoE. -// SPDX-License-Identifier: Apache-2.0 - -// TorchMoE Team - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "common/pytorch.h" -#include "common/types.h" -#include "memory/memory_pool.h" -#include "utils/noncopyable.h" - -enum NodeState { - NODE_STATE_NONE = 0x0, - NODE_STATE_CACHED = 0x1, - NODE_STATE_PREFETCHED = 0x2, - NODE_STATE_VISITED = 0x4, -}; - -extern cudaStream_t kCudaStreamH2D; - -struct Node { - std::vector tensor_ids; - std::int64_t byte_size; - std::size_t last_access_time; - std::size_t last_prefetch_time = 0; - - std::size_t id; - std::size_t corr_id; - - torch::Device device = DISK_DEVICE; - torch::Device default_device = DEFAULT_CUDA_DEVICE; // FIXME: should be set by scheduler - torch::Device default_host = CPU_DEVICE; - torch::Device initial_host = DISK_DEVICE; - - std::atomic_uint8_t state{0}; // 0 for ready, 1 for moving - - std::mutex mutex; - std::condition_variable cv; - - std::uint64_t visit_count = 0; - std::uint64_t unused_count = 0; - bool is_sparse = false; - NodeState io_state = NODE_STATE_NONE; - - bool is_overflow = false; - - void* host_memory_ptr = nullptr; - void* device_memory_ptr = nullptr; - -public: - explicit Node(); - const std::string str() noexcept; - void SetDevice(const torch::Device& target_device, - bool ondemand = false, - cudaStream_t stream = nullptr) noexcept; -}; - -typedef std::shared_ptr NodePtr; -typedef std::vector NodePtrList; -typedef std::tuple FilterResult; - -struct NodeBody; -typedef std::shared_ptr NodeBodyPtr; - -struct NodeBody { - NodePtr node; - std::vector children; - std::vector children_visit_cnt; - std::unordered_set activate_request; - std::size_t prefetch_cnt = 0; - std::size_t visit_cnt = 0; - std::size_t cpu_visit_cnt = 0; - std::size_t gpu_visit_cnt = 0; - std::size_t hit_cnt = 0; - std::size_t gpu_hit_cnt = 0; - std::size_t cpu_hit_cnt = 0; - std::size_t gpu_miss_cnt = 0; - std::size_t cpu_miss_cnt = 0; - bool is_sparse; - std::deque visit_time; - explicit NodeBody(NodePtr node) : node(node), visit_cnt(0) {} - - std::string str() const noexcept - { - std::stringstream ss; - ss << "NodeBody: " << node->str() << " visit_cnt " << visit_cnt << ", child visit ["; - for (auto& visit : children_visit_cnt) { ss << visit << ","; } - ss << "]"; - return ss.str(); - } -}; - -struct Stage { - bool is_sparse; - std::vector nodes; - std::size_t visit_cnt; - std::int64_t byte_size; - std::deque visit_time; - std::unordered_set activate_request; - Stage() : is_sparse(false), visit_cnt(0), byte_size(0) {} - Stage(bool is_sparse) : is_sparse(is_sparse), visit_cnt(0), byte_size(0) {} - - std::string str() const noexcept - { - char buffer[1024]; - memset(buffer, 0, 1024); - sprintf(buffer, "Stage[%ld,%ld,%d]", nodes.size(), visit_cnt, is_sparse); - return std::string(buffer); - } -}; -typedef std::shared_ptr StagePtr; - -struct Pipeline { - std::vector stages; - std::size_t visit_cnt = 0; - - std::string str() const noexcept - { - std::stringstream ss; - ss << "Pipeline: " << stages.size() << " stages; visit_cnt " << visit_cnt << std::endl; - return ss.str(); - } -}; -typedef std::shared_ptr PipelinePtr; - -class ArcherTopologyHandle : public noncopyable { -public: - DELETE_COPY_AND_ASSIGN(ArcherTopologyHandle); - - ArcherTopologyHandle(); - ~ArcherTopologyHandle() = default; - - bool IsLastNode(const NodePtr& node); - bool IsFirstNode(const NodePtr& node); - 
- NodePtrList GetLFUNodes(const torch::Device& device); - - NodePtrList GetDenseNodes(const NodePtr& node, const std::size_t& k); - NodePtrList GetSparseNodes(const NodePtr& node, const std::size_t& k); - NodePtrList GetDenseNodes(); - NodePtrList GetSparseNodes(); - - std::uint64_t GetLastActivateStage(const HashID& hash_id); - - void InitializeTopology( - const std::vector>>>& topology); - - void EnableTrace() noexcept { trace_enabled_ = true; } - void DisableTrace() noexcept { trace_enabled_ = false; } - - std::vector> GetNodeVisitCounts(); - std::vector GetChildVisitCounts(); - void SetNodeVisitCounts(const std::vector& visit_counts); - void SetChildVisitCounts(const std::vector& visit_counts); - - NodePtr GetNodeFromTensorID(const TensorID& tensor_id); - NodeBodyPtr GetNodeBodyFromCorrID(const std::uint64_t& correlation_id); - - std::tuple GetNumLayersAndExperts(); - - std::int64_t GetSparseCacheLimit(const torch::Device& device); - - std::size_t GetNumberOfStages() const noexcept { return pipeline_.stages.size(); } - -private: - Pipeline pipeline_; - std::unordered_set visited_; - std::unordered_map last_active_stage_; - std::vector lfu_nodes_; - std::unordered_map request_time_; - std::unordered_map request_trace_; - std::int64_t visit_count_ = 0; - std::mutex mutex_; - bool trace_enabled_ = true; - - std::unordered_map tensor_id_to_node_; -}; - -extern std::unique_ptr kTopologyHandle; - -#define CONTINUE_IF_NULL(node) \ - if (node == nullptr) continue; -#define BREAK_IF_NULL(node) \ - if (node == nullptr) break; - -extern std::mutex kReadMutex; - -void SetModuleDisk(std::vector& tensor_ids); -void SetModuleMemoryFromDisk(std::vector& tensor_ids, - void* host_ptr, - bool on_demand = false); -void SetModuleCudaMemoryFromCPU(std::vector& tensor_ids, - void* device_ptr, - const torch::Device& device); -void SetModuleMemoryFromCuda(std::vector& tensor_ids, void* host_ptr); From 618c0f2542aa2b65b5688b7f9b984443c692389e Mon Sep 17 00:00:00 2001 From: xly Date: Sun, 19 Jan 2025 13:30:46 +0000 Subject: [PATCH 3/4] format --- .github/workflows/build-test.yml | 2 +- .github/workflows/publish-test.yml | 6 +- .github/workflows/publish.yml | 4 +- .github/workflows/scripts/create-release.js | 2 +- .github/workflows/scripts/cuda-install.sh | 2 +- .github/workflows/scripts/free-disk-space.sh | 4 +- .gitignore | 2 + .pre-commit-config.yaml | 52 + MANIFEST.in | 2 +- README.md | 10 +- RELEASE.md | 4 +- core/aio/archer_aio_thread.cpp | 2 - core/aio/archer_aio_utils.cpp | 11 +- core/aio/archer_prio_aio_handle.cpp | 3 +- core/aio/archer_tensor_handle.cpp | 4 +- core/aio/archer_tensor_index.h | 1 + core/memory/memory_pool.h | 14 +- core/parallel/expert_dispatcher.cpp | 75 +- core/parallel/expert_module.cpp | 94 +- core/prefetch/archer_prefetch_handle.cpp | 16 +- core/prefetch/archer_prefetch_handle.h | 2 +- core/prefetch/task_scheduler.cpp | 5 +- core/python/py_archer_prefetch.cpp | 2 +- core/utils/archer_logger.cpp | 23 +- core/utils/archer_logger.h | 43 +- core/utils/cuda_utils.cpp | 11 +- core/utils/cuda_utils.h | 8 +- environment.yml | 113 - examples/interface_example.py | 31 +- examples/readme_example.py | 13 +- moe_infinity/common/__init__.py | 2 +- moe_infinity/common/constants.py | 16 +- moe_infinity/distributed/__init__.py | 4 +- moe_infinity/distributed/devicemap_manager.py | 14 +- moe_infinity/distributed/expert_executor.py | 77 +- moe_infinity/distributed/expert_prefetcher.py | 17 +- moe_infinity/entrypoints/__init__.py | 2 +- moe_infinity/entrypoints/big_modeling.py | 51 +- 
moe_infinity/memory/__init__.py | 6 +- moe_infinity/memory/expert_cache.py | 37 +- moe_infinity/memory/expert_entry.py | 7 +- moe_infinity/memory/expert_predictor.py | 23 +- moe_infinity/memory/expert_prefetcher.py | 20 +- moe_infinity/memory/expert_priority_score.py | 86 +- moe_infinity/memory/expert_tracer.py | 32 +- moe_infinity/models/__init__.py | 9 +- moe_infinity/models/arctic.py | 38 +- moe_infinity/models/deepseek.py | 90 + moe_infinity/models/grok.py | 60 +- moe_infinity/models/mixtral.py | 59 +- moe_infinity/models/model_utils.py | 4 +- .../models/modeling_arctic/__init__.py | 9 +- .../modeling_arctic/configuration_arctic.py | 7 +- .../models/modeling_arctic/modeling_arctic.py | 1080 ++++++--- .../modeling_arctic/tokenization_arctic.py | 3 +- .../models/modeling_deepseek/__init__.py | 3 + .../configuration_deepseek.py | 206 ++ .../modeling_deepseek/modeling_deepseek.py | 1922 +++++++++++++++++ .../tokenization_deepseek_fast.py | 38 + .../modeling_grok/configuration_grok1.py | 4 +- .../models/modeling_grok/modeling_grok1.py | 112 +- .../modeling_grok/modeling_grok1_outputs.py | 2 +- moe_infinity/models/nllb_moe.py | 48 +- moe_infinity/models/switch_transformers.py | 30 +- moe_infinity/runtime/__init__.py | 2 +- moe_infinity/runtime/model_offload.py | 322 +-- moe_infinity/utils/__init__.py | 8 +- moe_infinity/utils/arguments.py | 8 +- moe_infinity/utils/checkpoints.py | 39 +- moe_infinity/utils/config.py | 9 +- moe_infinity/utils/hf_config.py | 27 +- op_builder/__init__.py | 28 +- op_builder/all_ops.py | 20 +- op_builder/builder.py | 398 ++-- op_builder/prefetch.py | 73 +- pyproject.toml | 31 +- requirements.txt | 21 +- setup.cfg | 2 +- setup.py | 59 +- 79 files changed, 4401 insertions(+), 1325 deletions(-) create mode 100644 .pre-commit-config.yaml delete mode 100644 environment.yml create mode 100644 moe_infinity/models/deepseek.py create mode 100644 moe_infinity/models/modeling_deepseek/__init__.py create mode 100644 moe_infinity/models/modeling_deepseek/configuration_deepseek.py create mode 100644 moe_infinity/models/modeling_deepseek/modeling_deepseek.py create mode 100644 moe_infinity/models/modeling_deepseek/tokenization_deepseek_fast.py diff --git a/.github/workflows/build-test.yml b/.github/workflows/build-test.yml index 1575a76..f8a6582 100644 --- a/.github/workflows/build-test.yml +++ b/.github/workflows/build-test.yml @@ -51,4 +51,4 @@ jobs: pip install build - name: Build package - run: BUILD_OPS=1 python -m build \ No newline at end of file + run: BUILD_OPS=1 python -m build diff --git a/.github/workflows/publish-test.yml b/.github/workflows/publish-test.yml index 95d6ae5..8e6002e 100644 --- a/.github/workflows/publish-test.yml +++ b/.github/workflows/publish-test.yml @@ -22,7 +22,7 @@ jobs: VERSION_HASH=$(date +"%Y%m%d%H%M%S") echo "Generated version hash: $VERSION_HASH" echo $VERSION_HASH > version.txt - + - name: Upload version number as artifact uses: actions/upload-artifact@v2 with: @@ -84,7 +84,7 @@ jobs: asset_name=${wheel_name//"linux"/"manylinux1"} echo "wheel_name=${wheel_name}" >> $GITHUB_ENV echo "asset_name=${asset_name}" >> $GITHUB_ENV - + # only build source when the python version is 3.8 - name: Build Source @@ -102,4 +102,4 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1.8 with: repository-url: https://test.pypi.org/legacy/ - skip-existing: true \ No newline at end of file + skip-existing: true diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index a5692b0..fd98148 100644 --- a/.github/workflows/publish.yml +++ 
b/.github/workflows/publish.yml @@ -88,7 +88,7 @@ jobs: asset_name=${wheel_name//"linux"/"manylinux1"} echo "wheel_name=${wheel_name}" >> $GITHUB_ENV echo "asset_name=${asset_name}" >> $GITHUB_ENV - + # only build source when the python version is 3.8 - name: Build Source @@ -115,4 +115,4 @@ jobs: uses: pypa/gh-action-pypi-publish@release/v1.8 with: # repository-url: https://test.pypi.org/legacy/ - skip-existing: true \ No newline at end of file + skip-existing: true diff --git a/.github/workflows/scripts/create-release.js b/.github/workflows/scripts/create-release.js index 0f25624..a42464a 100644 --- a/.github/workflows/scripts/create-release.js +++ b/.github/workflows/scripts/create-release.js @@ -17,4 +17,4 @@ module.exports = async (github, context, core) => { } catch (error) { core.setFailed(error.message); } -} \ No newline at end of file +} diff --git a/.github/workflows/scripts/cuda-install.sh b/.github/workflows/scripts/cuda-install.sh index 197c699..312c6e8 100644 --- a/.github/workflows/scripts/cuda-install.sh +++ b/.github/workflows/scripts/cuda-install.sh @@ -20,4 +20,4 @@ nvcc --version # Log gcc, g++, c++ versions gcc --version g++ --version -c++ --version \ No newline at end of file +c++ --version diff --git a/.github/workflows/scripts/free-disk-space.sh b/.github/workflows/scripts/free-disk-space.sh index 8e50f47..16b3341 100644 --- a/.github/workflows/scripts/free-disk-space.sh +++ b/.github/workflows/scripts/free-disk-space.sh @@ -20,7 +20,7 @@ # Total space: 85GB # Allocated: 67 GB # Free: 17 GB -# This script frees up 28 GB of disk space by deleting unneeded packages and +# This script frees up 28 GB of disk space by deleting unneeded packages and # large directories. # The Flink end to end tests download and generate more than 17 GB of files, # causing unpredictable behavior and build failures. @@ -45,4 +45,4 @@ echo "Removing large directories" # deleting 15GB rm -rf /usr/share/dotnet/ rm -rf /opt/hostedtoolcache/ -df -h \ No newline at end of file +df -h diff --git a/.gitignore b/.gitignore index af078f8..2b9b8b3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,8 @@ # Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode # Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode +test* + ### VisualStudioCode ### .vscode/* # !.vscode/settings.json diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..e388231 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,52 @@ +repos: +- repo: meta + hooks: + - id: check-hooks-apply + - id: check-useless-excludes + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-case-conflict + # - id: check-json + # - id: check-symlinks + - id: check-yaml + - id: destroyed-symlinks + - id: end-of-file-fixer + - id: fix-byte-order-marker + - id: fix-encoding-pragma + args: [--remove] + - id: mixed-line-ending + args: [--fix=lf] + - id: requirements-txt-fixer + - id: trailing-whitespace + +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. 
+ rev: v0.6.9 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + # args: [--check] + +- repo: https://gitlab.com/daverona/pre-commit/cpp + rev: 0.8.0 + hooks: + - id: clang-format # formatter of C/C++ code based on a style guide: LLVM, Google, Chromium, Mozilla, and WebKit available + args: ['-style=file'] + # - id: cpplint + # - id: cppcheck # exclude some checks + +- repo: https://github.com/codespell-project/codespell + rev: v2.3.0 + hooks: + - id: codespell + args: [ + # Do not check files that are automatically generated + '--skip=docs/Gemfile.lock,tests/unit/gpt2-merges.txt,tests/unit/gpt2-vocab.json', + '--ignore-regex=\\n', # Do not count the 'n' in an escaped newline as part of a word + '--ignore-words-list=youn,unsupport,noe,ccompiler', # Word used in error messages that need rewording + --check-filenames, + --check-hidden + ] diff --git a/MANIFEST.in b/MANIFEST.in index 4306d43..76302f0 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ recursive-include core *.cpp *.h *.cc -recursive-include op_builder *.py \ No newline at end of file +recursive-include op_builder *.py diff --git a/README.md b/README.md index b89569e..0a3bc05 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,7 @@ MoE-Infinity is cost-effective yet fast: - Offloading MoE's experts to host memory, allowing memory-constrained GPUs to serve MoE models. - Minimizing the expert offloading overheads through several novel techniques: expert activation tracing, activation-aware expert prefetching, and activation-aware expert caching. - Supporting LLM acceleration techniques (such as [FlashAttention](https://github.com/Dao-AILab/flash-attention)). -- Supporting multi-GPU environments with numeorous OS-level performance optimizations. +- Supporting multi-GPU environments with numeorous OS-level performance optimizations. - Achieving SOTA latency and throughput performance when serving MoEs in a resource-constrained GPU environment (in comparison with HuggingFace [Accelerate](https://github.com/huggingface/accelerate), [DeepSpeed](https://github.com/microsoft/DeepSpeed), [Mixtral-Offloading](https://github.com/dvmazur/mixtral-offloading), and [Ollama/LLama.cpp](https://github.com/ollama/ollama)). MoE-Infinity is easy-to-use: @@ -41,7 +41,7 @@ Lower per-token-latency is preferable. | MoE-Infinity | *0.230* | *0.239* | *0.895* | | Accelerate | 1.043 | 3.071 | 6.633 | |DeepSpeed | 4.578 | 8.381 | 2.486 | -|Mixtral Offloading| X | X | 1.752 | +|Mixtral Offloading| X | X | 1.752 | |Ollama | X | X | 0.903 | @@ -53,7 +53,7 @@ Higher throughput is preferable. | MoE-Infinity | *69.105* | *30.300* | *12.579* | | Accelerate | 5.788 | 4.344 | 1.245 | |DeepSpeed | 7.416 | 4.334 | 7.727 | -|Mixtral Offloading| X | X | 7.684 | +|Mixtral Offloading| X | X | 7.684 | |Ollama | X | X | 1.107 | > The Mixtral Offloading experiment was carried out with a batch size of 16, as utilizing a batch size of 32 would result in Out of Memory errors on the GPU. @@ -145,14 +145,14 @@ CUDA_VISIBLE_DEVICES=0,1 python script.py We provide a simple example to run inference on a Huggingface LLM model. The script will download the model checkpoint and run inference on the specified input text. The output will be printed to the console. 
```bash -CUDA_VISIBLE_DEVICES=0 python examples/interface_example.py --model_name_or_path "mistralai/Mixtral-8x7B-Instruct-v0.1" --offload_dir +CUDA_VISIBLE_DEVICES=0 python examples/interface_example.py --model_name_or_path "mistralai/Mixtral-8x7B-Instruct-v0.1" --offload_dir ``` ## Release Plan We plan to release two functions in the following months: -* We currently support PyTorch as the default inference engine, and we are in the process of supporting vLLM as another inference runtime, which includes the support of KV cache offloading. +* We currently support PyTorch as the default inference engine, and we are in the process of supporting vLLM as another inference runtime, which includes the support of KV cache offloading. * Supporting expert parallelism for distributed MoE inference. * More (We welcome contributors to join us!) diff --git a/RELEASE.md b/RELEASE.md index 0d2a39e..7dba099 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -32,7 +32,7 @@ For developers who prefer to manually build and publish their package to PyPI, t 2. Install the required dependencies to build the package: ```bash pip install -r requirements.txt - pip install build + pip install build ``` 3. Build the source distribution and wheel for the package using: ```bash @@ -46,4 +46,4 @@ For developers who prefer to manually build and publish their package to PyPI, t Ensure that you have the necessary credentials configured for `twine` to authenticate to PyPI. -To build the package wheel for multiple Python versions, you should execute the build process individually for each version by specifying the corresponding Python interpreter. \ No newline at end of file +To build the package wheel for multiple Python versions, you should execute the build process individually for each version by specifying the corresponding Python interpreter. 
diff --git a/core/aio/archer_aio_thread.cpp b/core/aio/archer_aio_thread.cpp index 550eca6..e1fcd84 100644 --- a/core/aio/archer_aio_thread.cpp +++ b/core/aio/archer_aio_thread.cpp @@ -48,7 +48,6 @@ void ArcherAioThread::Wait() void ArcherAioThread::Run() { - while (is_running_) { std::function callback; { @@ -60,5 +59,4 @@ void ArcherAioThread::Run() callback(); pending_callbacks_.fetch_sub(1); } - } diff --git a/core/aio/archer_aio_utils.cpp b/core/aio/archer_aio_utils.cpp index a2eeb14..79fa8f1 100644 --- a/core/aio/archer_aio_utils.cpp +++ b/core/aio/archer_aio_utils.cpp @@ -4,10 +4,10 @@ // TorchMoE Team #include "archer_aio_utils.h" -#include -#include "utils/archer_logger.h" #include #include +#include +#include "utils/archer_logger.h" const size_t kBlockSize = 1 * 1024 * 1024; const size_t kQueueDepth = @@ -86,7 +86,7 @@ int ArcherWriteFileBatch(const int fd, const auto ret = future.get(); if (ret < 0) { ARCHER_LOG_FATAL( - "Failed to write file: ", fd,", errno: ", errno,", msg: ", strerror(errno)); + "Failed to write file: ", fd, ", errno: ", errno, ", msg: ", strerror(errno)); return -1; } } @@ -98,7 +98,8 @@ int ArcherReadFile(int fd, void* buffer, const size_t num_bytes, const size_t of { const auto ret = pread(fd, buffer, num_bytes, offset); if (ret < 0) { - ARCHER_LOG_FATAL("Failed to read file: ", fd,", errno: ", errno,", msg: ", strerror(errno)); + ARCHER_LOG_FATAL( + "Failed to read file: ", fd, ", errno: ", errno, ", msg: ", strerror(errno)); return -1; } @@ -110,7 +111,7 @@ int ArcherWriteFile(int fd, const void* buffer, size_t num_bytes, size_t offset) const auto ret = pwrite(fd, buffer, num_bytes, offset); if (ret < 0) { ARCHER_LOG_FATAL( - "Failed to write file: ", fd,", errno: ", errno,", msg: ", strerror(errno)); + "Failed to write file: ", fd, ", errno: ", errno, ", msg: ", strerror(errno)); return -1; } diff --git a/core/aio/archer_prio_aio_handle.cpp b/core/aio/archer_prio_aio_handle.cpp index 1d1d118..9f1c472 100644 --- a/core/aio/archer_prio_aio_handle.cpp +++ b/core/aio/archer_prio_aio_handle.cpp @@ -108,8 +108,7 @@ std::int64_t ArcherPrioAioHandle::Write(const std::string& filename, return num_bytes_aligned; } -ArcherPrioAioContext::ArcherPrioAioContext(const int block_size) - : block_size_(block_size) +ArcherPrioAioContext::ArcherPrioAioContext(const int block_size) : block_size_(block_size) { thread_pool_ = std::make_unique(1); // only one SSD device thread_pool_->Start(); diff --git a/core/aio/archer_tensor_handle.cpp b/core/aio/archer_tensor_handle.cpp index 2bab4df..7504026 100644 --- a/core/aio/archer_tensor_handle.cpp +++ b/core/aio/archer_tensor_handle.cpp @@ -32,7 +32,7 @@ ArcherTensorHandle::ArcherTensorHandle(const std::string& prefix) ARCHER_LOG_FATAL("Invalid prefix: ", prefix_, " is not a directory"); } if (stat(prefix_.c_str(), &st) == -1) { - ARCHER_LOG_WARN("Invalid prefix: ", prefix_," does not exist, creating"); + ARCHER_LOG_WARN("Invalid prefix: ", prefix_, " does not exist, creating"); mkdir(prefix_.c_str(), 0777); } @@ -44,7 +44,7 @@ ArcherTensorHandle::ArcherTensorHandle(const std::string& prefix) kTensorIndex->Deserialize(ckpt_index_path.c_str()); is_serialized_ = true; } else { - ARCHER_LOG_INFO("Index file", ckpt_index_path," does not exist, creating"); + ARCHER_LOG_INFO("Index file", ckpt_index_path, " does not exist, creating"); } ARCHER_LOG_INFO("Index file size ", kTensorIndex->size()); } diff --git a/core/aio/archer_tensor_index.h b/core/aio/archer_tensor_index.h index 59a5657..d42017b 100644 --- a/core/aio/archer_tensor_index.h 
+++ b/core/aio/archer_tensor_index.h @@ -47,6 +47,7 @@ class ArcherTensorIndex : public std::unordered_map ArcherTensorIndex() = default; ~ArcherTensorIndex() = default; + private: }; diff --git a/core/memory/memory_pool.h b/core/memory/memory_pool.h index 02424c3..1352de8 100644 --- a/core/memory/memory_pool.h +++ b/core/memory/memory_pool.h @@ -7,12 +7,12 @@ #include "common/pytorch.h" #include "utils/noncopyable.h" -#include "utils/archer_logger.h" -#include "host_caching_allocator.h" #include #include #include #include +#include "host_caching_allocator.h" +#include "utils/archer_logger.h" std::size_t GetTotalSystemMemory(); @@ -41,9 +41,7 @@ class HostMemoryPool : public noncopyable { { auto allocator = c10::HostCachingAllocator::get(); for (auto& [key, data_ptr] : allocated_id_) { - if (data_ptr != nullptr) { - allocator->free(data_ptr); - } + if (data_ptr != nullptr) { allocator->free(data_ptr); } } allocated_id_.clear(); } @@ -73,11 +71,9 @@ class DeviceMemoryPool : public noncopyable { virtual ~DeviceMemoryPool() { auto allocator = c10::cuda::CUDACachingAllocator::get(); - for(auto &allocated_id : allocated_id_){ + for (auto& allocated_id : allocated_id_) { for (auto& [key, data_ptr] : allocated_id) { - if (data_ptr != nullptr) { - allocator->raw_deallocate(data_ptr); - } + if (data_ptr != nullptr) { allocator->raw_deallocate(data_ptr); } } } allocated_id_.clear(); diff --git a/core/parallel/expert_dispatcher.cpp b/core/parallel/expert_dispatcher.cpp index 6a20266..5b0ed04 100644 --- a/core/parallel/expert_dispatcher.cpp +++ b/core/parallel/expert_dispatcher.cpp @@ -120,13 +120,18 @@ void ExpertDispatcher::Enqueue(const CallArgs& args) auto& a = input_queue_.back(); if (expert_node->node->device.is_cuda()) { a.gpu_id = expert_node->node->device.index(); } - ARCHER_LOG_DEBUG( - "ExpertDispatcher::Enqueue: num_enqueued_ ", num_enqueued_.load(), - "input_queue_ ", input_queue_.size(), \ - "gpu_id ", a.gpu_id, - "layer_idx ", a.layer_idx, - "expert_idx ", a.expert_idx, - "remote ", a.remote); + ARCHER_LOG_DEBUG("ExpertDispatcher::Enqueue: num_enqueued_ ", + num_enqueued_.load(), + "input_queue_ ", + input_queue_.size(), + "gpu_id ", + a.gpu_id, + "layer_idx ", + a.layer_idx, + "expert_idx ", + a.expert_idx, + "remote ", + a.remote); } void ExpertDispatcher::RegisterExpert(int layer_idx, @@ -220,10 +225,8 @@ void ExpertDispatcher::GPUFetchFunc(int gpu_id) wait_count++; // if (wait_count % 100000 == 0) { // ARCHER_LOG_WARN( - // "ExpertDispatcher::EnqueueTask: gpu_overload_ gpu_id {} wait_count {} {}", - // gpu_id, - // wait_count, - // expert_node->node->str()); + // "ExpertDispatcher::EnqueueTask: gpu_overload_ gpu_id {} wait_count {} + // {}", gpu_id, wait_count, expert_node->node->str()); // } } @@ -277,12 +280,16 @@ void ExpertDispatcher::GPUFetchFunc(int gpu_id) expert_type); } - ARCHER_LOG_DEBUG( - "ExpertDispatcher::GPUFetchFunc gpu_id ", gpu_id, - "layer_idx ", layer_idx, - "expert_idx ", expert_idx, - "input ", input.device().str(), - "node ", expert_node->node->device.str()); + ARCHER_LOG_DEBUG("ExpertDispatcher::GPUFetchFunc gpu_id ", + gpu_id, + "layer_idx ", + layer_idx, + "expert_idx ", + expert_idx, + "input ", + input.device().str(), + "node ", + expert_node->node->device.str()); { ExecArgs exec_args; exec_args.hidden_states = std::move(input); @@ -332,7 +339,7 @@ void ExpertDispatcher::GPUExecFunc(int gpu_id) auto* expert_module = args.expert_node->module; int expert_type = expert_type_; - cudaStreamSynchronize(0); // make sure the input is ready + 
cudaStreamSynchronize(0); // make sure the input is ready try { switch (expert_type) { @@ -358,9 +365,8 @@ void ExpertDispatcher::GPUExecFunc(int gpu_id) ->forward(args.hidden_states); break; default: - ARCHER_LOG_FATAL( - "ExpertDispatcher::ExpertDispatcher: unknown expert type", - expert_type); + ARCHER_LOG_FATAL("ExpertDispatcher::ExpertDispatcher: unknown expert type", + expert_type); } } catch (const std::exception& e) { @@ -368,7 +374,11 @@ void ExpertDispatcher::GPUExecFunc(int gpu_id) ss << "DenseActDense tensor_ids: ["; for (auto& id : args.expert_node->node->tensor_ids) { ss << id << " "; } ss << "]"; - ARCHER_LOG_FATAL("ExpertDispatcher::GPUExecFunc", ss.str(), "expert_type", expert_type, e.what()); + ARCHER_LOG_FATAL("ExpertDispatcher::GPUExecFunc", + ss.str(), + "expert_type", + expert_type, + e.what()); } stream.synchronize(); @@ -407,15 +417,18 @@ void ExpertDispatcher::OutputFunc(ExecArgs args, torch::Tensor output, int gpu_i args.expert_node->layer_idx, args.expert_node->expert_idx, args.hit); - ARCHER_LOG_DEBUG( - "ExpertDispatcher::OutputFunc: output_queue_", output_queue_.size(), - "output", std::get<0>(output_queue_.back()).device().str(), - "evict", args.evict, - "(", - args.expert_node->layer_idx, - args.expert_node->expert_idx, - gpu_id, - args.hit, ")"); + ARCHER_LOG_DEBUG("ExpertDispatcher::OutputFunc: output_queue_", + output_queue_.size(), + "output", + std::get<0>(output_queue_.back()).device().str(), + "evict", + args.evict, + "(", + args.expert_node->layer_idx, + args.expert_node->expert_idx, + gpu_id, + args.hit, + ")"); } stream.synchronize(); pending_.fetch_sub(1); diff --git a/core/parallel/expert_module.cpp b/core/parallel/expert_module.cpp index b301002..254cc28 100644 --- a/core/parallel/expert_module.cpp +++ b/core/parallel/expert_module.cpp @@ -16,9 +16,9 @@ SwitchTransformersDenseActDense::SwitchTransformersDenseActDense(int dtype) } void SwitchTransformersDenseActDense::SetTensorsFromBlob( - void *ptr, - const std::vector &tensor_ids, - const torch::Device &device) + void* ptr, + const std::vector& tensor_ids, + const torch::Device& device) { wi = kTensorIndex->find(tensor_ids[0])->second.tensor; wo = kTensorIndex->find(tensor_ids[1])->second.tensor; @@ -46,9 +46,9 @@ SwitchTransformersDenseGatedActDense::SwitchTransformersDenseGatedActDense(int d } void SwitchTransformersDenseGatedActDense::SetTensorsFromBlob( - void *ptr, - const std::vector &tensor_ids, - const torch::Device &device) + void* ptr, + const std::vector& tensor_ids, + const torch::Device& device) { wi_0 = kTensorIndex->find(tensor_ids[0])->second.tensor; wi_1 = kTensorIndex->find(tensor_ids[1])->second.tensor; @@ -72,9 +72,9 @@ NllbMoeDenseActDense::NllbMoeDenseActDense(int dtype) fc2_bias = register_parameter("fc2_bias", torch::zeros({1}, options)); } -void NllbMoeDenseActDense::SetTensorsFromBlob(void *ptr, - const std::vector &tensor_ids, - const torch::Device &device) +void NllbMoeDenseActDense::SetTensorsFromBlob(void* ptr, + const std::vector& tensor_ids, + const torch::Device& device) { fc1 = kTensorIndex->find(tensor_ids[0])->second.tensor; fc1_bias = kTensorIndex->find(tensor_ids[1])->second.tensor; @@ -84,7 +84,8 @@ void NllbMoeDenseActDense::SetTensorsFromBlob(void *ptr, torch::Tensor NllbMoeDenseActDense::forward(torch::Tensor hidden_states) { - // ARCHER_LOG_DEBUG("NllbMoeDenseActDense fc1 {} fc1_bias {} fc2 {} fc2_bias {} hidden_states {}", + // ARCHER_LOG_DEBUG("NllbMoeDenseActDense fc1 {} fc1_bias {} fc2 {} fc2_bias {} hidden_states + // {}", // fc1.device().str(), 
// fc1_bias.device().str(), // fc2.device().str(), @@ -105,9 +106,9 @@ FSGPTMoEDenseActDense::FSGPTMoEDenseActDense(int dtype) fc2_bias = register_parameter("fc2_bias", torch::zeros({1}, options)); } -void FSGPTMoEDenseActDense::SetTensorsFromBlob(void *ptr, - const std::vector &tensor_ids, - const torch::Device &device) +void FSGPTMoEDenseActDense::SetTensorsFromBlob(void* ptr, + const std::vector& tensor_ids, + const torch::Device& device) { fc1 = kTensorIndex->find(tensor_ids[0])->second.tensor; fc1_bias = kTensorIndex->find(tensor_ids[1])->second.tensor; @@ -117,14 +118,14 @@ void FSGPTMoEDenseActDense::SetTensorsFromBlob(void *ptr, torch::Tensor FSGPTMoEDenseActDense::forward(torch::Tensor hidden_states) { - // ARCHER_LOG_DEBUG("FSGPTMoEDenseActDense fc1 {} fc1_bias {} fc2 {} fc2_bias {} hidden_states {}", + // ARCHER_LOG_DEBUG("FSGPTMoEDenseActDense fc1 {} fc1_bias {} fc2 {} fc2_bias {} hidden_states + // {}", // fc1.device().str(), // fc1_bias.device().str(), // fc2.device().str(), // fc2_bias.device().str(), // hidden_states.device().str()); - if (hidden_states.dtype() != fc1.dtype()) - hidden_states = hidden_states.to(fc1.dtype()); + if (hidden_states.dtype() != fc1.dtype()) hidden_states = hidden_states.to(fc1.dtype()); return torch::matmul(torch::relu(torch::matmul(hidden_states, fc1.transpose(0, 1)) + fc1_bias), fc2.transpose(0, 1)) + fc2_bias; @@ -139,9 +140,9 @@ MixtralMoEDenseActDense::MixtralMoEDenseActDense(int dtype) w3 = register_parameter("w3", torch::zeros({1}, options)); } -void MixtralMoEDenseActDense::SetTensorsFromBlob(void *ptr, - const std::vector &tensor_ids, - const torch::Device &device) +void MixtralMoEDenseActDense::SetTensorsFromBlob(void* ptr, + const std::vector& tensor_ids, + const torch::Device& device) { w1 = kTensorIndex->find(tensor_ids[0])->second.tensor; w2 = kTensorIndex->find(tensor_ids[1])->second.tensor; @@ -159,42 +160,43 @@ torch::Tensor MixtralMoEDenseActDense::forward(torch::Tensor hidden_states) int w2_nan = torch::sum(torch::isnan(w2)).item(); int w3_nan = torch::sum(torch::isnan(w3)).item(); int hidden_states_nan = torch::sum(torch::isnan(hidden_states)).item(); - // std::cout << "MixtralMoEDenseActDense w1 " << w1_nan << " w2 " << w2_nan << " w3 " << w3_nan << " hidden_states " << hidden_states_nan << std::endl; + // std::cout << "MixtralMoEDenseActDense w1 " << w1_nan << " w2 " << w2_nan << " w3 " << w3_nan + // << " hidden_states " << hidden_states_nan << std::endl; assert(w1_nan == 0); assert(w2_nan == 0); assert(w3_nan == 0); assert(hidden_states_nan == 0); - return torch::matmul(torch::silu(torch::matmul(hidden_states, w1.transpose(0, 1))) * torch::matmul(hidden_states, w3.transpose(0, 1)), w2.transpose(0, 1)); + return torch::matmul(torch::silu(torch::matmul(hidden_states, w1.transpose(0, 1))) * + torch::matmul(hidden_states, w3.transpose(0, 1)), + w2.transpose(0, 1)); } -void ExpertNode::SetTensorsFromBlob(const torch::Device &device) +void ExpertNode::SetTensorsFromBlob(const torch::Device& device) { int expert_type = this->expert_type; - switch (expert_type) - { - case SWITCH_TRANSFORMERS_DENSE_ACT_DENSE: - reinterpret_cast(module)->SetTensorsFromBlob( - node->device_memory_ptr, node->tensor_ids, device); - break; - case SWITCH_TRANSFORMERS_DENSE_GATED_ACT_DENSE: - reinterpret_cast(module)->SetTensorsFromBlob( - node->device_memory_ptr, node->tensor_ids, device); - break; - case NLLB_MOE_DENSE_ACT_DENSE: - reinterpret_cast(module)->SetTensorsFromBlob( - node->device_memory_ptr, node->tensor_ids, device); - break; - case 
FSGPT_MOE_DENSE_ACT_DENSE: - reinterpret_cast(module)->SetTensorsFromBlob( - node->device_memory_ptr, node->tensor_ids, device); - break; - case MIXTRAL_MOE_DENSE_ACT_DENSE: - reinterpret_cast(module)->SetTensorsFromBlob( - node->device_memory_ptr, node->tensor_ids, device); - break; - default: - assert(false); + switch (expert_type) { + case SWITCH_TRANSFORMERS_DENSE_ACT_DENSE: + reinterpret_cast(module)->SetTensorsFromBlob( + node->device_memory_ptr, node->tensor_ids, device); + break; + case SWITCH_TRANSFORMERS_DENSE_GATED_ACT_DENSE: + reinterpret_cast(module)->SetTensorsFromBlob( + node->device_memory_ptr, node->tensor_ids, device); + break; + case NLLB_MOE_DENSE_ACT_DENSE: + reinterpret_cast(module)->SetTensorsFromBlob( + node->device_memory_ptr, node->tensor_ids, device); + break; + case FSGPT_MOE_DENSE_ACT_DENSE: + reinterpret_cast(module)->SetTensorsFromBlob( + node->device_memory_ptr, node->tensor_ids, device); + break; + case MIXTRAL_MOE_DENSE_ACT_DENSE: + reinterpret_cast(module)->SetTensorsFromBlob( + node->device_memory_ptr, node->tensor_ids, device); + break; + default: assert(false); } } diff --git a/core/prefetch/archer_prefetch_handle.cpp b/core/prefetch/archer_prefetch_handle.cpp index 7e1ddb1..876d4b4 100644 --- a/core/prefetch/archer_prefetch_handle.cpp +++ b/core/prefetch/archer_prefetch_handle.cpp @@ -12,8 +12,8 @@ #include "common/time.h" #include "memory/memory_pool.h" #include "task_scheduler.h" -#include "utils/cuda_utils.h" #include "utils/archer_logger.h" +#include "utils/cuda_utils.h" ArcherPrefetchHandle::ArcherPrefetchHandle(const std::string& prefix, const double device_memory_ratio) @@ -45,29 +45,26 @@ ArcherPrefetchHandle::ArcherPrefetchHandle(const std::string& prefix, if (can_access == 1) { cudaSetDevice(i); cudaError_t status = cudaDeviceEnablePeerAccess(j, 0); - if (status == cudaErrorPeerAccessAlreadyEnabled){ + if (status == cudaErrorPeerAccessAlreadyEnabled) { ARCHER_LOG_INFO("Peer access already enabled between device ", i, j); - cudaGetLastError(); // clear error + cudaGetLastError(); // clear error } else if (status != cudaSuccess) { ARCHER_LOG_ERROR("Failed to enable peer access between device ", i, j); } else { ARCHER_LOG_INFO("Enabled peer access between device ", i, j); } - } } } } - + ARCHER_LOG_INFO("Enabled peer access for all devices"); } ArcherPrefetchHandle::~ArcherPrefetchHandle() { // served as a global manager for order of destruction - if(!has_cleaned_up_resources_) { - CleanUpResources(); - } + if (!has_cleaned_up_resources_) { CleanUpResources(); } } void ArcherPrefetchHandle::CleanUpResources() @@ -393,7 +390,8 @@ int ArcherPrefetchHandle::GetNodeDevice(std::vector tensor_ids) c return node->device.index(); } -// void ArcherPrefetchHandle::SetNodeCachePriority(const std::uint32_t tensor_id, const float priority) { +// void ArcherPrefetchHandle::SetNodeCachePriority(const std::uint32_t tensor_id, const float +// priority) { // auto node = kTopologyHandle->GetNodeFromTensorID(tensor_id); // node->cache_priority = priority; // } diff --git a/core/prefetch/archer_prefetch_handle.h b/core/prefetch/archer_prefetch_handle.h index d5ba6aa..6f6dd57 100644 --- a/core/prefetch/archer_prefetch_handle.h +++ b/core/prefetch/archer_prefetch_handle.h @@ -6,8 +6,8 @@ #pragma once #include "aio/archer_tensor_handle.h" -#include "parallel/expert_dispatcher.h" #include "model/model_topology.h" +#include "parallel/expert_dispatcher.h" class ArcherPrefetchHandle { public: diff --git a/core/prefetch/task_scheduler.cpp b/core/prefetch/task_scheduler.cpp 
index 9f48fb4..b3a1a8e 100644 --- a/core/prefetch/task_scheduler.cpp +++ b/core/prefetch/task_scheduler.cpp @@ -382,9 +382,8 @@ bool ArcherTaskPool::RemoveCachedDenseNode(const NodePtr& node) // auto device_id = node->default_device.index(); -// auto cache_limit = kDeviceMemoryPool->GetMemoryCapacity(torch::Device(torch::kCUDA, device_id)); -// cache_limit -= node->byte_size; -// int64_t cache_size = 0; +// auto cache_limit = kDeviceMemoryPool->GetMemoryCapacity(torch::Device(torch::kCUDA, +// device_id)); cache_limit -= node->byte_size; int64_t cache_size = 0; // NodePtrList device_nodes; // for (auto& n : nodes) { diff --git a/core/python/py_archer_prefetch.cpp b/core/python/py_archer_prefetch.cpp index 65c3580..f9737d2 100644 --- a/core/python/py_archer_prefetch.cpp +++ b/core/python/py_archer_prefetch.cpp @@ -70,7 +70,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) .def("enqueue_prefetch", &ArcherPrefetchHandle::EnqueuePrefetch) .def("fetch_tensors", &ArcherPrefetchHandle::FetchTensors) .def("clean_up_resources", &ArcherPrefetchHandle::CleanUpResources); - // .def("set_node_cache_priority", &ArcherPrefetchHandle::SetNodeCachePriority); + // .def("set_node_cache_priority", &ArcherPrefetchHandle::SetNodeCachePriority); py::class_(m, "expert_dispatcher") .def(py::init()) diff --git a/core/utils/archer_logger.cpp b/core/utils/archer_logger.cpp index b2a776e..6ab324a 100644 --- a/core/utils/archer_logger.cpp +++ b/core/utils/archer_logger.cpp @@ -6,14 +6,13 @@ #include "archer_logger.h" #include -#include #include +#include std::once_flag kLoggerFlag; int kLogLevel = -1; std::mutex kLogMutex; - int str2level(const char* level) { if (strcmp(level, "info") == 0) { @@ -34,18 +33,12 @@ int str2level(const char* level) std::string level2str(int level) { switch (level) { - case kInfo: - return "INFO"; - case kError: - return "ERROR"; - case kWarn: - return "WARN"; - case kDebug: - return "DEBUG"; - case kFatal: - return "FATAL"; - default: - return "UNKNOWN"; + case kInfo: return "INFO"; + case kError: return "ERROR"; + case kWarn: return "WARN"; + case kDebug: return "DEBUG"; + case kFatal: return "FATAL"; + default: return "UNKNOWN"; } } @@ -56,7 +49,7 @@ std::string formatstr() auto ms = std::chrono::duration_cast(time.time_since_epoch()) % 1000; auto timer = std::chrono::system_clock::to_time_t(time); auto tm = *std::localtime(&timer); - + auto year = tm.tm_year + 1900; auto month = tm.tm_mon + 1; auto day = tm.tm_mday; diff --git a/core/utils/archer_logger.h b/core/utils/archer_logger.h index fc2bb62..446e80f 100644 --- a/core/utils/archer_logger.h +++ b/core/utils/archer_logger.h @@ -43,13 +43,13 @@ extern std::mutex kLogMutex; extern void InitLogger(); -#define ARCHER_LOG_DEBUG(...) \ - do { \ - if (kLogLevel <= kDebug) { \ - std::lock_guard lock(kLogMutex); \ +#define ARCHER_LOG_DEBUG(...) \ + do { \ + if (kLogLevel <= kDebug) { \ + std::lock_guard lock(kLogMutex); \ std::cout << formatstr() << level2str(kDebug) << " "; \ - print(__VA_ARGS__); \ - } \ + print(__VA_ARGS__); \ + } \ } while (0) #define ARCHER_LOG_INFO(...) \ @@ -57,17 +57,17 @@ extern void InitLogger(); if (kLogLevel <= kInfo) { \ std::lock_guard lock(kLogMutex); \ std::cout << formatstr() << level2str(kInfo) << " "; \ - print(__VA_ARGS__); \ + print(__VA_ARGS__); \ } \ } while (0) -#define ARCHER_LOG_ERROR(...) \ - do { \ - if (kLogLevel <= kError) { \ - std::lock_guard lock(kLogMutex); \ +#define ARCHER_LOG_ERROR(...) 
\ + do { \ + if (kLogLevel <= kError) { \ + std::lock_guard lock(kLogMutex); \ std::cout << formatstr() << level2str(kError) << " "; \ - print(__VA_ARGS__); \ - } \ + print(__VA_ARGS__); \ + } \ } while (0) #define ARCHER_LOG_WARN(...) \ @@ -75,17 +75,16 @@ extern void InitLogger(); if (kLogLevel <= kWarn) { \ std::lock_guard lock(kLogMutex); \ std::cout << formatstr() << level2str(kWarn) << " "; \ - print(__VA_ARGS__); \ + print(__VA_ARGS__); \ } \ } while (0) -#define ARCHER_LOG_FATAL(...) \ - do { \ - if (kLogLevel <= kError) { \ - std::lock_guard lock(kLogMutex); \ +#define ARCHER_LOG_FATAL(...) \ + do { \ + if (kLogLevel <= kError) { \ + std::lock_guard lock(kLogMutex); \ std::cout << formatstr() << level2str(kFatal) << " "; \ - print(__VA_ARGS__); \ - throw std::runtime_error("Logged a FATAL error"); \ - } \ + print(__VA_ARGS__); \ + throw std::runtime_error("Logged a FATAL error"); \ + } \ } while (0) - diff --git a/core/utils/cuda_utils.cpp b/core/utils/cuda_utils.cpp index 0daa473..ace8b97 100644 --- a/core/utils/cuda_utils.cpp +++ b/core/utils/cuda_utils.cpp @@ -46,11 +46,10 @@ int CudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) } int CudaMemcpyAsync(void* dst, - const void* src, - size_t count, - cudaMemcpyKind kind, - cudaStream_t stream) + const void* src, + size_t count, + cudaMemcpyKind kind, + cudaStream_t stream) { - return cudaMemcpyAsync( - dst, src, count, kind, stream); + return cudaMemcpyAsync(dst, src, count, kind, stream); } diff --git a/core/utils/cuda_utils.h b/core/utils/cuda_utils.h index 0bb1cc5..f8ffd28 100644 --- a/core/utils/cuda_utils.h +++ b/core/utils/cuda_utils.h @@ -18,7 +18,7 @@ std::size_t GetFreeDeviceMemory(int device_id); int CudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind); int CudaMemcpyAsync(void* dst, - const void* src, - size_t count, - cudaMemcpyKind kind, - cudaStream_t stream = 0); + const void* src, + size_t count, + cudaMemcpyKind kind, + cudaStream_t stream = 0); diff --git a/environment.yml b/environment.yml deleted file mode 100644 index a8f9846..0000000 --- a/environment.yml +++ /dev/null @@ -1,113 +0,0 @@ -name: moe-infinity -channels: - - conda-forge - - https://public.dhe.ibm.com/ibmdl/export/pub/software/server/ibm-ai/conda/ - - defaults -dependencies: - - _libgcc_mutex=0.1=main - - _openmp_mutex=5.1=1_gnu - - ca-certificates=2024.2.2=hbcca054_0 - - ld_impl_linux-64=2.38=h1181459_1 - - libffi=3.4.4=h6a678d5_0 - - libgcc-ng=11.2.0=h1234567_1 - - libgomp=11.2.0=h1234567_1 - - libstdcxx-ng=12.3.0=h0f45ef3_5 - - ncurses=6.4=h6a678d5_0 - - openssl=3.0.13=h7f8727e_0 - - pip=23.3.1=py39h06a4308_0 - - python=3.9.18=h955ad1f_0 - - readline=8.2=h5eee18b_0 - - setuptools=68.2.2=py39h06a4308_0 - - sqlite=3.41.2=h5eee18b_0 - - tk=8.6.12=h1ccaba5_0 - - wheel=0.41.2=py39h06a4308_0 - - xz=5.4.6=h5eee18b_0 - - zlib=1.2.13=h5eee18b_0 - - pip: - - accelerate==0.27.2 - - aiohttp==3.9.3 - - aiosignal==1.3.1 - - alabaster==0.7.16 - - async-timeout==4.0.3 - - attrs==23.2.0 - - auto-gptq==0.7.0 - - babel==2.14.0 - - certifi==2024.2.2 - - chardet==5.2.0 - - charset-normalizer==3.3.2 - - coloredlogs==15.0.1 - - datasets==2.17.1 - - dill==0.3.8 - - docutils==0.20.1 - - filelock==3.13.1 - - frozenlist==1.4.1 - - fsspec==2023.10.0 - - gekko==1.0.6 - - hjson==3.1.0 - - huggingface-hub==0.21.1 - - humanfriendly==10.0 - - idna==3.6 - - imagesize==1.4.1 - - importlib-metadata==7.0.1 - - jinja2==3.1.3 - - markupsafe==2.1.5 - - moe-infinity==0.0.1.dev4 - - mpmath==1.3.0 - - multidict==6.0.5 - - 
multiprocess==0.70.16 - - networkx==3.2.1 - - ninja==1.11.1.1 - - numpy==1.26.4 - - nvidia-cublas-cu12==12.1.3.1 - - nvidia-cuda-cupti-cu12==12.1.105 - - nvidia-cuda-nvrtc-cu12==12.1.105 - - nvidia-cuda-runtime-cu12==12.1.105 - - nvidia-cudnn-cu12==8.9.2.26 - - nvidia-cufft-cu12==11.0.2.54 - - nvidia-curand-cu12==10.3.2.106 - - nvidia-cusolver-cu12==11.4.5.107 - - nvidia-cusparse-cu12==12.1.0.106 - - nvidia-nccl-cu12==2.19.3 - - nvidia-nvjitlink-cu12==12.3.101 - - nvidia-nvtx-cu12==12.1.105 - - optimum==1.17.1 - - packaging==23.2 - - pandas==2.2.1 - - peft==0.9.0 - - protobuf==4.25.3 - - psutil==5.9.8 - - py-cpuinfo==9.0.0 - - pyarrow==12.0.0 - - pyarrow-hotfix==0.6 - - pydantic==1.10.12 - - pygments==2.17.2 - - python-dateutil==2.8.2 - - pytz==2024.1 - - pyyaml==6.0.1 - - regex==2023.12.25 - - requests==2.31.0 - - rouge==1.0.1 - - safetensors==0.4.2 - - scipy==1.12.0 - - sentencepiece==0.2.0 - - six==1.16.0 - - snowballstemmer==2.2.0 - - sphinx==7.2.6 - - sphinxcontrib-applehelp==1.0.8 - - sphinxcontrib-devhelp==1.0.6 - - sphinxcontrib-htmlhelp==2.0.5 - - sphinxcontrib-jsmath==1.0.1 - - sphinxcontrib-qthelp==1.0.7 - - sphinxcontrib-serializinghtml==1.1.10 - - sympy==1.12 - - tokenizers==0.15.2 - - torch==2.2.1 - - tqdm==4.66.2 - - transformers==4.38.1 - - triton==2.2.0 - - typing-extensions==4.10.0 - - tzdata==2024.1 - - urllib3==2.2.1 - - xxhash==3.4.1 - - yarl==1.9.4 - - zipp==3.17.0 diff --git a/examples/interface_example.py b/examples/interface_example.py index 20bde32..fd59474 100644 --- a/examples/interface_example.py +++ b/examples/interface_example.py @@ -3,15 +3,17 @@ # TorchMoE Team -from functools import partial -import os -import torch import argparse -import datasets import multiprocessing as mp -from transformers import AutoTokenizer, TextStreamer, LlamaTokenizerFast +import os +from functools import partial + +import datasets +import torch +from transformers import AutoTokenizer, LlamaTokenizerFast, TextStreamer + from moe_infinity import MoE -from moe_infinity.models.arctic import ArcticTokenizer +from moe_infinity.models.modeling_arctic import ArcticTokenizer parser = argparse.ArgumentParser() parser.add_argument("--model_name_or_path", type=str, required=True) @@ -23,11 +25,15 @@ tokenizer = None if "grok" in model_name: - tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer", trust_remote_code=True) + tokenizer = LlamaTokenizerFast.from_pretrained( + "Xenova/grok-1-tokenizer", trust_remote_code=True + ) elif "arctic" in args.model_name_or_path.lower(): tokenizer = ArcticTokenizer.from_pretrained(args.model_name_or_path) else: - tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + args.model_name_or_path, trust_remote_code=True, use_fast=False + ) streamer = TextStreamer(tokenizer) dataset_name = "tasksource/bigbench" @@ -42,7 +48,9 @@ all_inputs = [None] * len(names) all_inputs = pool.map(partial(datasets.load_dataset, dataset_name), names) -all_inputs = [text for dataset in all_inputs for text in dataset["validation"]["inputs"] ] +all_inputs = [ + text for dataset in all_inputs for text in dataset["validation"]["inputs"] +] config = { "offload_path": os.path.join(args.offload_dir, model_name), @@ -55,19 +63,20 @@ if "switch" in args.model_name_or_path.lower(): custom_kwargs = {"decoder_start_token_id": 0} elif "nllb" in args.model_name_or_path.lower(): - custom_kwargs = {"forced_bos_token_id": 256057} # translate to French + custom_kwargs = {"forced_bos_token_id": 
256057} # translate to French elif "mixtral" in args.model_name_or_path.lower(): custom_kwargs = {"pad_token_id": tokenizer.eos_token_id} elif "grok" in args.model_name_or_path.lower(): custom_kwargs = {} elif "arctic" in args.model_name_or_path.lower(): custom_kwargs = {"pad_token_id": tokenizer.eos_token_id} +elif "deepseek" in args.model_name_or_path.lower(): + custom_kwargs = {"pad_token_id": tokenizer.eos_token_id} else: raise ValueError(f"Model {args.model_name_or_path} not supported") tokenizer.pad_token = tokenizer.eos_token for input_text in all_inputs: - inputs = tokenizer( input_text, truncation=True, diff --git a/examples/readme_example.py b/examples/readme_example.py index 2f1129c..73749f8 100644 --- a/examples/readme_example.py +++ b/examples/readme_example.py @@ -1,16 +1,17 @@ -import torch import os -from transformers import AutoTokenizer, SwitchTransformersForConditionalGeneration + +from transformers import AutoTokenizer + from moe_infinity import MoE -user_home = os.path.expanduser('~') +user_home = os.path.expanduser("~") -checkpoint = 'TheBloke/Mixtral-8x7B-v0.1-GPTQ' +checkpoint = "TheBloke/Mixtral-8x7B-v0.1-GPTQ" tokenizer = AutoTokenizer.from_pretrained(checkpoint) config = { "offload_path": os.path.join(user_home, "moe-infinity"), - "device_memory_ratio": 0.75, # 75% of the device memory is used for caching, change the value according to your device memory size on OOM + "device_memory_ratio": 0.75, # 75% of the device memory is used for caching, change the value according to your device memory size on OOM } model = MoE(checkpoint, config) @@ -21,4 +22,4 @@ output_ids = model.generate(input_ids) output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) -print(output_text) \ No newline at end of file +print(output_text) diff --git a/moe_infinity/common/__init__.py b/moe_infinity/common/__init__.py index 9737af3..ee3f060 100644 --- a/moe_infinity/common/__init__.py +++ b/moe_infinity/common/__init__.py @@ -1 +1 @@ -from .constants import * \ No newline at end of file +from .constants import * diff --git a/moe_infinity/common/constants.py b/moe_infinity/common/constants.py index 60d21df..f6792b9 100644 --- a/moe_infinity/common/constants.py +++ b/moe_infinity/common/constants.py @@ -1,13 +1,18 @@ from transformers import ( - SwitchTransformersForConditionalGeneration, - NllbMoeForConditionalGeneration, MixtralForCausalLM, + NllbMoeForConditionalGeneration, OPTForCausalLM, PretrainedConfig, + SwitchTransformersForConditionalGeneration, ) -from ..models.modeling_grok.modeling_grok1 import Grok1ModelForCausalLM # TODO: Replace this with huggingface transformers -from ..models.modeling_arctic import ArcticForCausalLM # TODO: Replace this with huggingface transformers +from ..models.modeling_arctic import ( + ArcticForCausalLM, +) # TODO: Replace this with huggingface transformers +from ..models.modeling_deepseek import DeepseekV2ForCausalLM +from ..models.modeling_grok.modeling_grok1 import ( + Grok1ModelForCausalLM, +) # TODO: Replace this with huggingface transformers MODEL_MAPPING_NAMES = { "switch": SwitchTransformersForConditionalGeneration, @@ -16,6 +21,7 @@ "opt": OPTForCausalLM, "grok": Grok1ModelForCausalLM, "arctic": ArcticForCausalLM, + "deepseek": DeepseekV2ForCausalLM, } MODEL_MAPPING_TYPES = { @@ -24,8 +30,10 @@ "mixtral": 4, "grok": 4, "arctic": 4, + "deepseek": 4, } + def parse_expert_type(config: PretrainedConfig) -> int: architecture = config.architectures[0].lower() arch = None diff --git a/moe_infinity/distributed/__init__.py 
b/moe_infinity/distributed/__init__.py index 60e22b3..9a34853 100644 --- a/moe_infinity/distributed/__init__.py +++ b/moe_infinity/distributed/__init__.py @@ -3,6 +3,6 @@ # TorchMoE Team -from .expert_prefetcher import DistributedExpertPrefetcher -from .expert_executor import DistributedExpertExecutor from .devicemap_manager import DeviceMapManager +from .expert_executor import DistributedExpertExecutor +from .expert_prefetcher import DistributedExpertPrefetcher diff --git a/moe_infinity/distributed/devicemap_manager.py b/moe_infinity/distributed/devicemap_manager.py index b7562fb..2f008f5 100644 --- a/moe_infinity/distributed/devicemap_manager.py +++ b/moe_infinity/distributed/devicemap_manager.py @@ -4,12 +4,14 @@ # TorchMoE Team # The global device manager shared among all nodes, using grpc server to communicate with each other. -from typing import Tuple, List -import numpy as np import random -from moe_infinity.utils import ArcherConfig +from typing import List, Tuple + import torch.distributed as dist +from moe_infinity.utils import ArcherConfig + + class DeviceMapManager: def __init__(self, archer_config: ArcherConfig) -> None: world_size = dist.get_world_size() @@ -33,7 +35,9 @@ def set_expert_tensor_map(self, expert_tensor_map): def set_archer_engine(self, archer_engine): self.archer_engine = archer_engine - def get_target_device(self, expert_list: List[int]) -> List[Tuple[int, int, int]]: + def get_target_device( + self, expert_list: List[int] + ) -> List[Tuple[int, int, int]]: num_experts = len(expert_list) num_device = self.total_device @@ -67,5 +71,3 @@ def get_target_device(self, expert_list: List[int]) -> List[Tuple[int, int, int] k += 1 return device_list - - diff --git a/moe_infinity/distributed/expert_executor.py b/moe_infinity/distributed/expert_executor.py index b2ba2b1..7ee80bc 100644 --- a/moe_infinity/distributed/expert_executor.py +++ b/moe_infinity/distributed/expert_executor.py @@ -3,10 +3,11 @@ # TorchMoE Team -import torch.distributed.rpc as rpc -import torch.distributed as dist -import torch import numpy as np +import torch +import torch.distributed as dist +import torch.distributed.rpc as rpc + from moe_infinity.utils import ArcherConfig @@ -17,7 +18,6 @@ def _call_expert_dispatcher(method, *args, **kwargs): class DistributedExpertExecutor: - def __init__(self, archer_config: ArcherConfig): self.archer_config = archer_config @@ -31,9 +31,16 @@ def set_device_map_manager(self, device_map_manager): def dispatch_local(self, hidden_states, router_mask, layer_id): num_expert = router_mask.shape[-1] - expert_count = torch.sum(router_mask.view((-1, num_expert)), dim=0).cpu().numpy().flatten() - - expert_list = np.arange(num_expert).astype(int)[expert_count > 0].tolist() + expert_count = ( + torch.sum(router_mask.view((-1, num_expert)), dim=0) + .cpu() + .numpy() + .flatten() + ) + + expert_list = ( + np.arange(num_expert).astype(int)[expert_count > 0].tolist() + ) expected_wait_cnt = len(expert_list) self.expert_dispatcher.set_inputs(hidden_states, router_mask) @@ -42,7 +49,9 @@ def dispatch_local(self, hidden_states, router_mask, layer_id): total_gpus = torch.cuda.device_count() for expert_id in expert_list: gpu_id = expert_id % total_gpus - self.expert_dispatcher.enqueue_expert(layer_id, expert_id, gpu_id, False) + self.expert_dispatcher.enqueue_expert( + layer_id, expert_id, gpu_id, False + ) result = self.expert_dispatcher.wait_expert() @@ -50,10 +59,16 @@ def dispatch_local(self, hidden_states, router_mask, layer_id): def dispatch(self, hidden_states, 
router_mask, layer_id): num_expert = router_mask.shape[-1] - expert_count = torch.sum(router_mask.view((-1, num_expert)), dim=0).cpu().numpy().flatten() + expert_count = ( + torch.sum(router_mask.view((-1, num_expert)), dim=0) + .cpu() + .numpy() + .flatten() + ) - expert_list = np.arange(num_expert).astype(int)[ - expert_count > 0].tolist() + expert_list = ( + np.arange(num_expert).astype(int)[expert_count > 0].tolist() + ) device_list = self.device_map_manager.get_target_device(expert_list) visited_ranks = set() @@ -66,15 +81,17 @@ def dispatch(self, hidden_states, router_mask, layer_id): futures = [] for rank in visited_ranks: if rank != dist.get_rank(): - future = rpc.rpc_async(f"worker_{rank}", - _call_expert_dispatcher, - args=("set_inputs", hidden_states.cpu(), - router_mask.cpu())) + future = rpc.rpc_async( + f"worker_{rank}", + _call_expert_dispatcher, + args=("set_inputs", hidden_states.cpu(), router_mask.cpu()), + ) futures.append(future) - future = rpc.rpc_async(f"worker_{rank}", - _call_expert_dispatcher, - args=("set_expected_queue", - rank_wait_cnt[rank])) + future = rpc.rpc_async( + f"worker_{rank}", + _call_expert_dispatcher, + args=("set_expected_queue", rank_wait_cnt[rank]), + ) futures.append(future) else: self.expert_dispatcher.set_inputs(hidden_states, router_mask) @@ -88,13 +105,15 @@ def dispatch(self, hidden_states, router_mask, layer_id): for k, device_meta in enumerate(device_list): rank, gpu_id, expert_id = device_meta if rank == dist.get_rank(): - self.expert_dispatcher.enqueue_expert(layer_id, expert_id, - gpu_id, False) + self.expert_dispatcher.enqueue_expert( + layer_id, expert_id, gpu_id, False + ) else: - future = rpc.rpc_async(f"worker_{rank}", - _call_expert_dispatcher, - args=("enqueue_expert", layer_id, - expert_id, gpu_id, True)) + future = rpc.rpc_async( + f"worker_{rank}", + _call_expert_dispatcher, + args=("enqueue_expert", layer_id, expert_id, gpu_id, True), + ) futures.append(future) # wait for all futures @@ -104,9 +123,11 @@ def dispatch(self, hidden_states, router_mask, layer_id): result_list = [] for rank in visited_ranks: if rank != dist.get_rank(): - result = rpc.rpc_sync(f"worker_{rank}", - _call_expert_dispatcher, - args=("wait_expert", )) + result = rpc.rpc_sync( + f"worker_{rank}", + _call_expert_dispatcher, + args=("wait_expert",), + ) result_list += result else: result = self.expert_dispatcher.wait_expert() diff --git a/moe_infinity/distributed/expert_prefetcher.py b/moe_infinity/distributed/expert_prefetcher.py index e989ae9..0bec6ea 100644 --- a/moe_infinity/distributed/expert_prefetcher.py +++ b/moe_infinity/distributed/expert_prefetcher.py @@ -3,13 +3,10 @@ # TorchMoE Team -import time -import numpy as np -import torch -from torch.distributed import rpc import torch.distributed as dist - +from torch.distributed import rpc from transformers import PretrainedConfig + from moe_infinity.utils import parse_moe_param @@ -24,8 +21,8 @@ class DistributedExpertPrefetcher(object): def __init__(self, config: PretrainedConfig): print(config) - self.num_layers, self.num_experts, self.num_encoder_layers = parse_moe_param( - config + self.num_layers, self.num_experts, self.num_encoder_layers = ( + parse_moe_param(config) ) def set_archer_engine(self, archer_engine): @@ -46,9 +43,11 @@ def prefetch_experts(self, layer_id, expert_matrix): for j in range(self.num_experts): if expert_matrix[i, j] > 0: expert_list.append( - (self.expert_tensor_map[(i,j)], expert_matrix[i, j]) + (self.expert_tensor_map[(i, j)], expert_matrix[i, j]) ) - 
ordered_expert_list = sorted(expert_list, key=lambda x: x[1], reverse=True) + ordered_expert_list = sorted( + expert_list, key=lambda x: x[1], reverse=True + ) tensor_ids = [x[0] for x in ordered_expert_list] device_list = self.device_map_manager.get_target_device(tensor_ids) diff --git a/moe_infinity/entrypoints/__init__.py b/moe_infinity/entrypoints/__init__.py index a700b88..6d22e0c 100644 --- a/moe_infinity/entrypoints/__init__.py +++ b/moe_infinity/entrypoints/__init__.py @@ -1 +1 @@ -from .big_modeling import MoE \ No newline at end of file +from .big_modeling import MoE diff --git a/moe_infinity/entrypoints/big_modeling.py b/moe_infinity/entrypoints/big_modeling.py index c9265a0..9f2da42 100644 --- a/moe_infinity/entrypoints/big_modeling.py +++ b/moe_infinity/entrypoints/big_modeling.py @@ -1,19 +1,18 @@ -from typing import Any, Union, Dict import os +from typing import Any, Dict, Union + import torch -import torch.nn as nn import transformers -from transformers import AutoConfig +from accelerate.utils.versions import is_torch_version from huggingface_hub import snapshot_download -from accelerate import init_empty_weights +from transformers import AutoConfig -from accelerate.utils.versions import is_torch_version +import moe_infinity from moe_infinity.common.constants import MODEL_MAPPING_NAMES -from moe_infinity.runtime import OffloadEngine -from moe_infinity.utils import get_checkpoint_paths, ArcherConfig from moe_infinity.models import apply_rotary_pos_emb -import moe_infinity from moe_infinity.models.modeling_arctic import ArcticConfig +from moe_infinity.runtime import OffloadEngine +from moe_infinity.utils import ArcherConfig, get_checkpoint_paths class MoE: @@ -58,7 +57,9 @@ def __init__( ) if config is None: - default_config_path = os.path.join(os.path.dirname(__file__), "config.json") + default_config_path = os.path.join( + os.path.dirname(__file__), "config.json" + ) if not os.path.exists(default_config_path): raise RuntimeError( "The `load_checkpoint_and_dispatch` function requires a configuration file. " @@ -66,9 +67,13 @@ def __init__( ) config = default_config_path if "arctic" in model_name_or_path: - model_config = ArcticConfig.from_pretrained(model_name_or_path, trust_remote_code=True) + model_config = ArcticConfig.from_pretrained( + model_name_or_path, trust_remote_code=True + ) else: - model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) + model_config = AutoConfig.from_pretrained( + model_name_or_path, trust_remote_code=True + ) architecture = model_config.architectures[0].lower() arch = None @@ -124,6 +129,7 @@ def __init__( "[WARNING] FlashAttention is not available in the current environment. Using default attention." 
) pass + with self.engine.init(cls=model_cls, ar_config=config): self.model = model_cls.from_pretrained( model_name_or_path, @@ -133,22 +139,19 @@ def __init__( is_flash_attn_available=is_flash_attn_available, trust_remote_code=True, ) - + def _configure_hook(self, input_ids: torch.LongTensor): if self.arch == "mixtral": - transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb = ( - apply_rotary_pos_emb - ) + transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb = apply_rotary_pos_emb if self.arch == "grok": - moe_infinity.models.modeling_grok.modeling_grok1.apply_rotary_pos_emb = ( - apply_rotary_pos_emb - ) + moe_infinity.models.modeling_grok.modeling_grok1.apply_rotary_pos_emb = apply_rotary_pos_emb if self.arch == "arctic": - moe_infinity.models.modeling_arctic.modeling_arctic.apply_rotary_pos_emb = ( - apply_rotary_pos_emb - ) + moe_infinity.models.modeling_arctic.modeling_arctic.apply_rotary_pos_emb = apply_rotary_pos_emb + + if self.arch == "deepseek": + moe_infinity.models.modeling_deepseek.modeling_deepseek.apply_rotary_pos_emb = apply_rotary_pos_emb batch_size = input_ids.shape[0] self.seq_id_list = [ @@ -191,9 +194,9 @@ def forward(self, input_ids: torch.LongTensor, *args, **kwargs) -> Any: Returns: Any: The output of the model. """ - + self._configure_hook(input_ids) - + return self.model(input_ids, *args, **kwargs) def __call__(self, *args, **kwargs) -> Any: @@ -207,4 +210,4 @@ def __call__(self, *args, **kwargs) -> Any: Returns: Any: The output of the model. """ - return self.forward(*args, **kwargs) \ No newline at end of file + return self.forward(*args, **kwargs) diff --git a/moe_infinity/memory/__init__.py b/moe_infinity/memory/__init__.py index ccb5954..3315f3d 100644 --- a/moe_infinity/memory/__init__.py +++ b/moe_infinity/memory/__init__.py @@ -3,8 +3,8 @@ # TorchMoE Team -from .expert_tracer import ExpertTracer -from .expert_predictor import ExpertPredictor from .expert_cache import ExpertCache -from .expert_priority_score import * +from .expert_predictor import ExpertPredictor from .expert_prefetcher import ExpertPrefetcher +from .expert_priority_score import * +from .expert_tracer import ExpertTracer diff --git a/moe_infinity/memory/expert_cache.py b/moe_infinity/memory/expert_cache.py index 97b88d5..03d5237 100644 --- a/moe_infinity/memory/expert_cache.py +++ b/moe_infinity/memory/expert_cache.py @@ -1,11 +1,11 @@ +import logging import sys -from moe_infinity.memory.expert_priority_score import * -from moe_infinity.memory.expert_entry import ExpertCacheEntry +from collections import Counter import numpy as np -from collections import Counter -import logging +from moe_infinity.memory.expert_entry import ExpertCacheEntry +from moe_infinity.memory.expert_priority_score import * class ExpertCache: @@ -127,7 +127,7 @@ def gpu_evict(self, seq_id, layer_idx): else: assert False, "Should not reach here" - cache_candidates.sort(key=lambda x: x.r) # sort by r acending + cache_candidates.sort(key=lambda x: x.r) # sort by r ascending cache_candidates_in_gpu = [ x for x in cache_candidates @@ -139,9 +139,9 @@ def gpu_evict(self, seq_id, layer_idx): for candidate in cache_candidates: candidate_key = (candidate.expert_idx, candidate.layer_idx) if ( - not candidate_key in self.experts_protected_ondemand - and not candidate_key in self.experts_protected_prefetch - and not candidate_key in self.experts_protected_by_layer + candidate_key not in self.experts_protected_ondemand + and candidate_key not in self.experts_protected_prefetch + and candidate_key not in 
self.experts_protected_by_layer ): if candidate_key in self.gpu_expert_cache: del self.gpu_expert_cache[candidate_key] @@ -153,9 +153,11 @@ def gpu_evict(self, seq_id, layer_idx): candidate_key = (candidate.expert_idx, candidate.layer_idx) if candidate_key in self.gpu_expert_cache: if self.cache_policy == "priority": - if not candidate_key in self.experts_protected_ondemand: + if candidate_key not in self.experts_protected_ondemand: del self.gpu_expert_cache[candidate_key] - self.logger.debug(f"Force evicting expert {candidate_key}") + self.logger.debug( + f"Force evicting expert {candidate_key}" + ) return True else: del self.gpu_expert_cache[candidate_key] @@ -180,13 +182,13 @@ def cpu_evict(self, seq_id, layer_idx): layer_idx, self.tracer.num_layers, ) - cache_candidates.sort(key=lambda x: x.r) # sort by r acending + cache_candidates.sort(key=lambda x: x.r) # sort by r ascending for candidate in cache_candidates: candidate_key = (candidate.expert_idx, candidate.layer_idx) if ( - not candidate_key in self.experts_protected_ondemand - and not candidate_key in self.experts_protected_prefetch + candidate_key not in self.experts_protected_ondemand + and candidate_key not in self.experts_protected_prefetch ): if candidate_key in self.cpu_expert_cache: del self.cpu_expert_cache[candidate_key] @@ -263,7 +265,8 @@ def visit(self, expert_idx, layer_idx): def protect_experts_by_layer(self, layer_idx: int): self.experts_protected_by_layer = { - (expert_idx, layer_idx): ExpertCacheEntry(expert_idx, layer_idx) for expert_idx in range(self.tracer.num_experts) + (expert_idx, layer_idx): ExpertCacheEntry(expert_idx, layer_idx) + for expert_idx in range(self.tracer.num_experts) } def protect_experts_on_demand( @@ -279,7 +282,8 @@ def protect_experts_on_demand( ) self.experts_protected_ondemand = { - (entry.expert_idx, entry.layer_idx): entry for entry in cache_entries + (entry.expert_idx, entry.layer_idx): entry + for entry in cache_entries } def protect_experts_prefetch(self, matrix, layer_idx: int): @@ -290,7 +294,8 @@ def protect_experts_prefetch(self, matrix, layer_idx: int): cache_entries.append(ExpertCacheEntry(e, l, matrix[l, e])) self.experts_protected_prefetch = { - (entry.expert_idx, entry.layer_idx): entry for entry in cache_entries + (entry.expert_idx, entry.layer_idx): entry + for entry in cache_entries } diff --git a/moe_infinity/memory/expert_entry.py b/moe_infinity/memory/expert_entry.py index 9fe9c61..3b48bc8 100644 --- a/moe_infinity/memory/expert_entry.py +++ b/moe_infinity/memory/expert_entry.py @@ -1,6 +1,7 @@ from dataclasses import dataclass + import numpy as np -import hashlib + @dataclass class ExpertTraceEntry: @@ -11,7 +12,7 @@ class ExpertTraceEntry: def __hash__(self): return hash(self.seq_id) - + @dataclass class ExpertCacheEntry: @@ -22,4 +23,4 @@ class ExpertCacheEntry: timestamp: int = 0 def __hash__(self): - return hash((self.layer_idx, self.expert_idx)) \ No newline at end of file + return hash((self.layer_idx, self.expert_idx)) diff --git a/moe_infinity/memory/expert_predictor.py b/moe_infinity/memory/expert_predictor.py index 2637818..d4d7860 100644 --- a/moe_infinity/memory/expert_predictor.py +++ b/moe_infinity/memory/expert_predictor.py @@ -1,14 +1,15 @@ -import time -from moe_infinity.memory.expert_tracer import ExpertTracer -from moe_infinity.memory.expert_entry import ExpertCacheEntry -import copy from transformers import PretrainedConfig + +from moe_infinity.memory.expert_tracer import ExpertTracer from moe_infinity.utils import parse_moe_param + class 
ExpertPredictor: def __init__(self, config: PretrainedConfig) -> None: - self.num_layers, self.num_experts, self.num_encoder_layers = parse_moe_param(config) - self.layer_decay_func = lambda x, l, L: -1 / (L+1) * (x-l) + 1 + self.num_layers, self.num_experts, self.num_encoder_layers = ( + parse_moe_param(config) + ) + self.layer_decay_func = lambda x, l, L: -1 / (L + 1) * (x - l) + 1 def add_tracer(self, tracer: ExpertTracer): self.tracer = tracer @@ -18,15 +19,17 @@ def predict(self, seq_id, expert_list, layer_idx): current_entry = self.tracer.get_entry(seq_id) # start_time = time.time() - expert_matrix = self.tracer.find_most_similar(current_entry.matrix, layer_idx) + expert_matrix = self.tracer.find_most_similar( + current_entry.matrix, layer_idx + ) # print("find_most_similar", time.time() - start_time) # expert_matrix = copy.deepcopy(entry) expert_matrix[:layer_idx, :] = 0 for l in range(layer_idx, self.num_layers): - expert_matrix[l] = (expert_matrix[l] + 1e-8) * self.layer_decay_func(l, layer_idx, self.num_layers) + expert_matrix[l] = ( + expert_matrix[l] + 1e-8 + ) * self.layer_decay_func(l, layer_idx, self.num_layers) return expert_matrix - - \ No newline at end of file diff --git a/moe_infinity/memory/expert_prefetcher.py b/moe_infinity/memory/expert_prefetcher.py index 45fe69c..02c5fc5 100644 --- a/moe_infinity/memory/expert_prefetcher.py +++ b/moe_infinity/memory/expert_prefetcher.py @@ -3,22 +3,20 @@ # TorchMoE Team -import time -import numpy as np -import torch -from torch.distributed import rpc -import torch.distributed as dist from transformers import PretrainedConfig + from moe_infinity.utils import parse_moe_param + class ExpertPrefetcher(object): cache_file_rd = None + first_k_dense_replace: int = 0 def __init__(self, config: PretrainedConfig): print(config) - self.num_layers, self.num_experts, self.num_encoder_layers = parse_moe_param( - config + self.num_layers, self.num_experts, self.num_encoder_layers = ( + parse_moe_param(config) ) def set_archer_engine(self, archer_engine): @@ -33,11 +31,13 @@ def prefetch_experts(self, layer_id, expert_matrix): for j in range(self.num_experts): if expert_matrix[i, j] > 0: expert_list.append( - (self.expert_tensor_map[(i,j)], expert_matrix[i, j]) + (self.expert_tensor_map[(i, j)], expert_matrix[i, j]) ) - ordered_expert_list = sorted(expert_list, key=lambda x: x[1], reverse=True) + ordered_expert_list = sorted( + expert_list, key=lambda x: x[1], reverse=True + ) tensor_ids = [x[0] for x in ordered_expert_list] - + self.archer_engine.replace_cache_candidates(tensor_ids) for tensor_id in tensor_ids: gpu_id = self.archer_engine.get_node_default_device([tensor_id]) diff --git a/moe_infinity/memory/expert_priority_score.py b/moe_infinity/memory/expert_priority_score.py index 6b4e56f..b3eb9db 100644 --- a/moe_infinity/memory/expert_priority_score.py +++ b/moe_infinity/memory/expert_priority_score.py @@ -1,14 +1,15 @@ -from collections import Counter -import numpy as np import copy -from typing import Dict, Set, List, Tuple +from typing import Set + +import numpy as np from moe_infinity.memory.expert_entry import ExpertCacheEntry, ExpertTraceEntry decay_from_first = lambda x, L: -1 / L * x + 1 -decay_from_last = lambda x, L: 1 / (L+1) * x +decay_from_last = lambda x, L: 1 / (L + 1) * x + +layer_decay = lambda x, l: (x + 1) / np.abs(l - x + 1) -layer_decay = lambda x, l: (x+1) / np.abs(l-x + 1) def convert_score_matrix_to_list(score_matrix: np.ndarray): score_list = [] @@ -18,21 +19,35 @@ def convert_score_matrix_to_list(score_matrix: 
np.ndarray): score_list.append(ExpertCacheEntry(expert_idx, layer_idx, r)) return score_list + def lru_score(cache_entries: Set[ExpertCacheEntry]): lru_score = [] for entry in cache_entries: - lru_score.append(ExpertCacheEntry(entry.expert_idx, entry.layer_idx, entry.timestamp)) + lru_score.append( + ExpertCacheEntry(entry.expert_idx, entry.layer_idx, entry.timestamp) + ) return lru_score + def lru_score_with_layers(cache_entries: Set[ExpertCacheEntry], current_layer): lru_score = [] for entry in cache_entries: - if entry.layer_idx >= current_layer and entry.layer_idx < current_layer + 3: - lru_score.append(ExpertCacheEntry(entry.expert_idx, entry.layer_idx, 1e10)) + if ( + entry.layer_idx >= current_layer + and entry.layer_idx < current_layer + 3 + ): + lru_score.append( + ExpertCacheEntry(entry.expert_idx, entry.layer_idx, 1e10) + ) else: - lru_score.append(ExpertCacheEntry(entry.expert_idx, entry.layer_idx, entry.timestamp)) + lru_score.append( + ExpertCacheEntry( + entry.expert_idx, entry.layer_idx, entry.timestamp + ) + ) return lru_score + def lfu_score(expert_freq: dict): # convert to list of tuples sum = 0 @@ -41,14 +56,15 @@ def lfu_score(expert_freq: dict): if sum == 0: sum = 1 - + lfu_score = [] for key, value in expert_freq.items(): expert_idx, layer_idx = key lfu_score.append(ExpertCacheEntry(expert_idx, layer_idx, value / sum)) - + return lfu_score + def oracle_score(expert_freq: dict, decoder_entry: ExpertTraceEntry): frequency_score = np.zeros_like(decoder_entry.matrix) frequency_sum = 0 @@ -65,7 +81,15 @@ def oracle_score(expert_freq: dict, decoder_entry: ExpertTraceEntry): return convert_score_matrix_to_list(frequency_score) -def priority_score(expert_freq: dict, cache_entries: Set[ExpertCacheEntry], trace_entries: Set[ExpertTraceEntry], decoder_entry: ExpertTraceEntry, current_layer, total_layer): + +def priority_score( + expert_freq: dict, + cache_entries: Set[ExpertCacheEntry], + trace_entries: Set[ExpertTraceEntry], + decoder_entry: ExpertTraceEntry, + current_layer, + total_layer, +): num_encoder_layers = total_layer // 2 # print("Cache entries size", len(cache_entries)) frequency_score = np.zeros_like(decoder_entry.matrix) @@ -76,12 +100,14 @@ def priority_score(expert_freq: dict, cache_entries: Set[ExpertCacheEntry], trac if np.sum(frequency_score[num_encoder_layers:]) == 0: frequency_score[num_encoder_layers:] = 1 - + if np.sum(frequency_score[:num_encoder_layers]) == 0: frequency_score[:num_encoder_layers] = 1 frequency_score = frequency_score / np.sum(frequency_score) + 1e-6 - assert np.sum(frequency_score) > 0, f"frequency_score = {frequency_score}, frequency_sum = {frequency_sum}" + assert ( + np.sum(frequency_score) > 0 + ), f"frequency_score = {frequency_score}, frequency_sum = {frequency_sum}" # print("frequency_score", np.sum(frequency_score), np.max(frequency_score), np.min(frequency_score), frequency_score.shape) topo_expert_score = np.zeros_like(decoder_entry.matrix) @@ -89,14 +115,26 @@ def priority_score(expert_freq: dict, cache_entries: Set[ExpertCacheEntry], trac entry_layer_idx = i if current_layer < num_encoder_layers: if i < num_encoder_layers: - topo_expert_score[i, :] = decay_from_first(i, num_encoder_layers) if i > current_layer else 1.0 + topo_expert_score[i, :] = ( + decay_from_first(i, num_encoder_layers) + if i > current_layer + else 1.0 + ) else: - topo_expert_score[i, :] = decay_from_last(i-num_encoder_layers, num_encoder_layers) + topo_expert_score[i, :] = decay_from_last( + i - num_encoder_layers, num_encoder_layers + ) else: if i < 
num_encoder_layers: - topo_expert_score[i, :] = decay_from_first(i, num_encoder_layers) + topo_expert_score[i, :] = decay_from_first( + i, num_encoder_layers + ) else: - topo_expert_score[i, :] = decay_from_last(i-num_encoder_layers, num_encoder_layers) if i > current_layer else 1.0 + topo_expert_score[i, :] = ( + decay_from_last(i - num_encoder_layers, num_encoder_layers) + if i > current_layer + else 1.0 + ) topo_expert_score = topo_expert_score / np.sum(topo_expert_score) + 1e-6 seq_expert_score = np.zeros_like(decoder_entry.matrix) @@ -104,13 +142,13 @@ def priority_score(expert_freq: dict, cache_entries: Set[ExpertCacheEntry], trac # zero_access = False for entry in trace_entries: freq_sum += entry.access - + if freq_sum == 0: # freq_sum = len(trace_entries) seq_expert_score = np.ones_like(seq_expert_score) freq_sum = 1 zero_access = True - else: + else: for entry in trace_entries: matrix = copy.deepcopy(entry.matrix) matrix[matrix > 0] = 1 @@ -127,10 +165,10 @@ def priority_score(expert_freq: dict, cache_entries: Set[ExpertCacheEntry], trac for i in range(decoder_matrix.shape[0]): if np.sum(decoder_matrix[i, :]) == 0: decoder_matrix[i, :] = 1 - decoder_matrix[i, :] = decoder_matrix[i, :] / np.sum(decoder_matrix[i, :]) + decoder_matrix[i, :] = decoder_matrix[i, :] / np.sum( + decoder_matrix[i, :] + ) decoder_matrix = decoder_matrix / np.sum(decoder_matrix) + 1e-6 total_score = topo_expert_score * decoder_matrix * frequency_score - - return convert_score_matrix_to_list(total_score) - + return convert_score_matrix_to_list(total_score) diff --git a/moe_infinity/memory/expert_tracer.py b/moe_infinity/memory/expert_tracer.py index a56e4c7..9d275e3 100644 --- a/moe_infinity/memory/expert_tracer.py +++ b/moe_infinity/memory/expert_tracer.py @@ -1,17 +1,15 @@ import copy import os -import time -from typing import Union -import numpy as np import uuid from collections import Counter -from scipy.spatial.distance import cosine +from typing import Union + +import numpy as np import torch import torch.nn as nn from transformers import PretrainedConfig # from sklearn.metrics.pairwise import cosine_similarity - from moe_infinity.memory.expert_entry import ExpertTraceEntry from moe_infinity.utils import parse_moe_param @@ -23,31 +21,36 @@ def __new__(cls, *args, **kwargs): if cls._instance is None: cls._instance = super(ExpertTracer, cls).__new__(cls) return cls._instance - - def __init__(self, capacity: int, config:PretrainedConfig): - self.num_layers, self.num_experts, self.num_encoder_layers = parse_moe_param(config) + + def __init__(self, capacity: int, config: PretrainedConfig): + self.num_layers, self.num_experts, self.num_encoder_layers = ( + parse_moe_param(config) + ) self.capacity = capacity self.trace = {} - self.trace_collection = torch.zeros((capacity, self.num_layers, self.num_experts), device="cuda:0") + self.trace_collection = torch.zeros( + (capacity, self.num_layers, self.num_experts), device="cuda:0" + ) self.collection_access = np.zeros((capacity,)) self.cos = nn.CosineSimilarity(dim=2, eps=1e-6) def load_trace(self, trace: Union[os.PathLike, np.ndarray]): if isinstance(trace, os.PathLike): - self.trace_collection = torch.from_numpy(np.load(trace, allow_pickle=False)) + self.trace_collection = torch.from_numpy( + np.load(trace, allow_pickle=False) + ) elif isinstance(trace, np.ndarray): self.trace_collection = trace - + self.persistent_capacity = self.trace_collection.shape[0] assert self.persistent_capacity <= self.capacity, ( f"loaded trace capacity {self.persistent_capacity} 
must be " f"less than or equal to capacity in config {self.capacity}" ) - def create_entry(self): seq_id = uuid.uuid4().hex self.trace[seq_id] = ExpertTraceEntry( @@ -94,7 +97,9 @@ def find_most_similar(self, matrix, layer_idx) -> np.ndarray: trace_collection_copy[:, : (layer_idx + 1), :] = 1e-9 # print("trace_collection copy", time.time() - start_time) - trace_collection_copy /= torch.sum(trace_collection_copy, dim=2, keepdims=True) + trace_collection_copy /= torch.sum( + trace_collection_copy, dim=2, keepdims=True + ) matrix_copy = torch.from_numpy(matrix.copy()).to("cuda:0") matrix_copy /= torch.sum(matrix_copy, dim=1, keepdims=True) @@ -118,4 +123,3 @@ def find_most_similar(self, matrix, layer_idx) -> np.ndarray: entry = self.trace_collection[min_idx].to("cpu").numpy() return entry - diff --git a/moe_infinity/models/__init__.py b/moe_infinity/models/__init__.py index 6ad9507..19e3fe3 100644 --- a/moe_infinity/models/__init__.py +++ b/moe_infinity/models/__init__.py @@ -3,9 +3,10 @@ # TorchMoE Team -from .switch_transformers import SyncSwitchTransformersSparseMLP -from .nllb_moe import SyncNllbMoeSparseMLP -from .mixtral import SyncMixtralSparseMoeBlock +from .arctic import ArcticConfig, SyncArcticMoeBlock +from .deepseek import DeepseekV2MoEBlock from .grok import SyncGrokMoeBlock -from .arctic import SyncArcticMoeBlock, ArcticConfig +from .mixtral import SyncMixtralSparseMoeBlock from .model_utils import apply_rotary_pos_emb +from .nllb_moe import SyncNllbMoeSparseMLP +from .switch_transformers import SyncSwitchTransformersSparseMLP diff --git a/moe_infinity/models/arctic.py b/moe_infinity/models/arctic.py index 7805c82..e14e198 100644 --- a/moe_infinity/models/arctic.py +++ b/moe_infinity/models/arctic.py @@ -1,34 +1,36 @@ -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple + import torch -import torch.nn.functional as F import torch.nn as nn -from .modeling_arctic import ArcticConfig -from .modeling_arctic import ArcticMLP +import torch.nn.functional as F from moe_infinity.utils import ArcherConfig -from .model_utils import apply_rotary_pos_emb + +from .modeling_arctic import ArcticConfig, ArcticMLP + class SyncArcticMoeBlock(nn.Module): archer_config: ArcherConfig = None layer_id: int = None - + def __init__(self, config: ArcticConfig, layer_id: int, **kwargs): super().__init__() self.hidden_dim = config.hidden_size self.num_experts = config.num_local_experts - self.layer_id = layer_id + self.layer_id = layer_id self.top_k = config.num_experts_per_tok - self.is_moe_layer = (layer_id+1) % config.moe_layer_frequency == 0 + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = nn.ModuleList([ArcticMLP(config) for i in range(self.num_experts)]) + self.experts = nn.ModuleList( + [ArcticMLP(config) for i in range(self.num_experts)] + ) self.archer_tracer = None self.archer_engine = None self.expert_tensor_ids: Dict[int, int] = None - - + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -42,11 +44,17 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) - expert_index = selected_experts.reshape(batch_size, sequence_length, self.top_k) + expert_index = selected_experts.reshape( + batch_size, sequence_length, self.top_k + ) 
for i in range(batch_size): seq_id = self.seq_id_list[i] - expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id) - self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix) + expert_matrix = self.expert_predictor.predict( + seq_id, expert_index[i], self.layer_id + ) + self.expert_prefetcher.prefetch_experts( + self.layer_id, expert_matrix + ) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), @@ -58,4 +66,4 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: expert_mask = torch.nn.functional.one_hot( selected_experts, num_classes=self.num_experts ).permute(2, 1, 0) - return final_hidden_states, expert_mask \ No newline at end of file + return final_hidden_states, expert_mask diff --git a/moe_infinity/models/deepseek.py b/moe_infinity/models/deepseek.py new file mode 100644 index 0000000..6eb5061 --- /dev/null +++ b/moe_infinity/models/deepseek.py @@ -0,0 +1,90 @@ +from typing import Dict, Optional, Tuple +import torch +import torch.nn.functional as F +import torch.nn as nn + +from .modeling_deepseek import DeepseekV2MLP, MoEGate + + +class DeepseekV2MoEBlock(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + self.experts = nn.ModuleList( + [ + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=intermediate_size + ) + + + self.archer_tracer = None + self.archer_engine = None + self.expert_tensor_ids: Dict[int, int] = None + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + cnts = topk_idx.new_zeros((topk_idx.shape[0], len(self.experts))) + cnts.scatter_(1, topk_idx, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_idx.view(-1).argsort() + sorted_tokens = hidden_states[idxs // topk_idx.shape[1]] + + tokens_per_expert = tokens_per_expert.cpu().numpy() + + batch_size, sequence_length, _ = orig_shape + router_mask = F.one_hot(topk_idx, num_classes=self.config.n_routed_experts) + + # print("router_mask", router_mask.shape) + + expert_index = topk_idx.reshape(batch_size, sequence_length, self.config.num_experts_per_tok) + for i in range(batch_size): + seq_id = self.seq_id_list[i] + expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id) + self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix) + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out.to(hidden_states.device)) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + y = ( + new_x.view(*topk_idx.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + 
) + + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y \ No newline at end of file diff --git a/moe_infinity/models/grok.py b/moe_infinity/models/grok.py index d8847a6..695ca5c 100644 --- a/moe_infinity/models/grok.py +++ b/moe_infinity/models/grok.py @@ -1,35 +1,36 @@ -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple + import torch -import torch.nn.functional as F import torch.nn as nn -from .modeling_grok import Grok1Config -from .modeling_grok import MoeBlock, MoeMLP - +import torch.nn.functional as F from moe_infinity.utils import ArcherConfig -from .model_utils import apply_rotary_pos_emb +from .modeling_grok import MoeMLP class SyncGrokMoeBlock(nn.Module): archer_config: ArcherConfig = None layer_id: int = None - - def __init__(self, hidden_dim: int, ffn_dim: int, num_experts: int, top_k: int): + + def __init__( + self, hidden_dim: int, ffn_dim: int, num_experts: int, top_k: int + ): super().__init__() self.hidden_dim = hidden_dim self.ffn_dim = ffn_dim self.num_experts = num_experts self.top_k = top_k # gating - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = nn.ModuleList([MoeMLP(hidden_dim, ffn_dim) for _ in range(self.num_experts)]) + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + self.experts = nn.ModuleList( + [MoeMLP(hidden_dim, ffn_dim) for _ in range(self.num_experts)] + ) self.archer_tracer = None self.archer_engine = None self.expert_tensor_ids: Dict[int, int] = None - - + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) @@ -44,19 +45,27 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: routing_weights = routing_weights.to(hidden_states.dtype) router_mask = F.one_hot(selected_experts, num_classes=self.num_experts) - routing_weights_mask = (routing_weights[:, :, None] * router_mask).permute( - 0, 2, 1 - ) + routing_weights_mask = ( + routing_weights[:, :, None] * router_mask + ).permute(0, 2, 1) router_mask = router_mask.permute(0, 2, 1) # assume top-2 here - router_mask = torch.logical_or(router_mask[:, :, 0], router_mask[:, :, 1]) + router_mask = torch.logical_or( + router_mask[:, :, 0], router_mask[:, :, 1] + ) routing_weights_mask = torch.sum(routing_weights_mask, dim=-1) - expert_index = selected_experts.reshape(batch_size, sequence_length, self.top_k) + expert_index = selected_experts.reshape( + batch_size, sequence_length, self.top_k + ) for i in range(batch_size): seq_id = self.seq_id_list[i] - expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id) - self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix) + expert_matrix = self.expert_predictor.predict( + seq_id, expert_index[i], self.layer_id + ) + self.expert_prefetcher.prefetch_experts( + self.layer_id, expert_matrix + ) final_hidden_states = torch.zeros( (batch_size * sequence_length, hidden_dim), @@ -64,10 +73,15 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: device=hidden_states.device, ) - results = self.expert_executor.dispatch_local(hidden_states, router_mask, self.layer_id) + results = self.expert_executor.dispatch_local( + hidden_states, router_mask, self.layer_id + ) for output, _, idx, _ in results: token_indices = router_mask[:, idx].bool() - final_hidden_states[token_indices, :] += output.to(routing_weights_mask.device) * 
routing_weights_mask[token_indices, idx][:, None] + final_hidden_states[token_indices, :] += ( + output.to(routing_weights_mask.device) + * routing_weights_mask[token_indices, idx][:, None] + ) # # One hot encode the selected experts to create an expert mask # # this will be used to easily index which expert is going to be sollicitated @@ -102,8 +116,8 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: # final_hidden_states.index_add_( # 0, top_x, current_hidden_states.to(hidden_states.dtype) # ) - + final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim ) - return final_hidden_states, router_logits \ No newline at end of file + return final_hidden_states, router_logits diff --git a/moe_infinity/models/mixtral.py b/moe_infinity/models/mixtral.py index c6de55f..8f717b9 100644 --- a/moe_infinity/models/mixtral.py +++ b/moe_infinity/models/mixtral.py @@ -3,19 +3,18 @@ # TorchMoE Team -import time -from typing import Dict, Optional +from typing import Dict + import torch -import torch.nn.functional as F import torch.nn as nn -import transformers -from transformers import MixtralConfig +import torch.nn.functional as F from transformers.models.mixtral.modeling_mixtral import ( MixtralBlockSparseTop2MLP, ) from moe_infinity.utils import ArcherConfig + class SyncMixtralSparseMoeBlock(nn.Module): archer_config: ArcherConfig = None layer_id: int = None @@ -30,7 +29,9 @@ def __init__(self, config): # gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + self.experts = nn.ModuleList( + [MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)] + ) self.archer_tracer = None self.archer_engine = None @@ -45,40 +46,56 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) - router_mask = F.one_hot(selected_experts, num_classes=self.num_experts) - routing_weights_mask = (routing_weights[:, :, None] * router_mask).permute( - 0, 2, 1 - ) + routing_weights_mask = ( + routing_weights[:, :, None] * router_mask + ).permute(0, 2, 1) router_mask = router_mask.permute(0, 2, 1) # assume top-2 here - router_mask = torch.logical_or(router_mask[:, :, 0], router_mask[:, :, 1]) + router_mask = torch.logical_or( + router_mask[:, :, 0], router_mask[:, :, 1] + ) routing_weights_mask = torch.sum(routing_weights_mask, dim=-1) # print("selected_experts", selected_experts) - expert_index = selected_experts.reshape(batch_size, sequence_length, self.top_k) + expert_index = selected_experts.reshape( + batch_size, sequence_length, self.top_k + ) for i in range(batch_size): seq_id = self.seq_id_list[i] # start_time = time.time() - expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id) + expert_matrix = self.expert_predictor.predict( + seq_id, expert_index[i], self.layer_id + ) # print("predict", time.time() - start_time) # start_time = time.time() - self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix) + self.expert_prefetcher.prefetch_experts( + 
self.layer_id, expert_matrix + ) # print("prefetch", time.time() - start_time) final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, ) - results = self.expert_executor.dispatch_local(hidden_states, router_mask, self.layer_id) + results = self.expert_executor.dispatch_local( + hidden_states, router_mask, self.layer_id + ) for output, _, idx, _ in results: token_indices = router_mask[:, idx].bool() - final_hidden_states[token_indices, :] += output.to(routing_weights_mask.device) * routing_weights_mask[token_indices, idx][:, None] + final_hidden_states[token_indices, :] += ( + output.to(routing_weights_mask.device) + * routing_weights_mask[token_indices, idx][:, None] + ) # for expert_idx in range(self.num_experts): # # expert_layer = self.experts[expert_idx] @@ -92,7 +109,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # ) # final_hidden_states[token_indices, :] += current_hidden_states - - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) return final_hidden_states, router_logits - diff --git a/moe_infinity/models/model_utils.py b/moe_infinity/models/model_utils.py index 38a97c4..697159c 100644 --- a/moe_infinity/models/model_utils.py +++ b/moe_infinity/models/model_utils.py @@ -1,11 +1,13 @@ import torch + def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] x2 = x[..., x.shape[-1] // 2 :] return torch.cat((-x2, x1), dim=-1) + def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): device = position_ids.device position_ids = position_ids.to(cos.device) @@ -15,4 +17,4 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) position_ids = position_ids.to(device) - return q_embed, k_embed \ No newline at end of file + return q_embed, k_embed diff --git a/moe_infinity/models/modeling_arctic/__init__.py b/moe_infinity/models/modeling_arctic/__init__.py index 8dd0c0d..f6d0b04 100644 --- a/moe_infinity/models/modeling_arctic/__init__.py +++ b/moe_infinity/models/modeling_arctic/__init__.py @@ -1,3 +1,8 @@ -from .modeling_arctic import ArcticForCausalLM, apply_rotary_pos_emb, ArcticMLP, ArcticMoE from .configuration_arctic import ArcticConfig -from .tokenization_arctic import ArcticTokenizer \ No newline at end of file +from .modeling_arctic import ( + ArcticForCausalLM, + ArcticMLP, + ArcticMoE, + apply_rotary_pos_emb, +) +from .tokenization_arctic import ArcticTokenizer diff --git a/moe_infinity/models/modeling_arctic/configuration_arctic.py b/moe_infinity/models/modeling_arctic/configuration_arctic.py index 2962478..fee0007 100644 --- a/moe_infinity/models/modeling_arctic/configuration_arctic.py +++ b/moe_infinity/models/modeling_arctic/configuration_arctic.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-""" Arctic model configuration""" +"""Arctic model configuration""" from dataclasses import asdict, dataclass from typing import Any, Dict @@ -19,7 +19,6 @@ from transformers.configuration_utils import PretrainedConfig from transformers.utils import logging - logger = logging.get_logger(__name__) ARCTIC_PRETRAINED_CONFIG_ARCHIVE_MAP = { @@ -198,7 +197,9 @@ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "ArcticConfig": else: config = result if isinstance(config.quantization, dict): - config.quantization = ArcticQuantizationConfig(**config.quantization) + config.quantization = ArcticQuantizationConfig( + **config.quantization + ) return result def to_dict(self) -> Dict[str, Any]: diff --git a/moe_infinity/models/modeling_arctic/modeling_arctic.py b/moe_infinity/models/modeling_arctic/modeling_arctic.py index 427fb78..563a6f8 100644 --- a/moe_infinity/models/modeling_arctic/modeling_arctic.py +++ b/moe_infinity/models/modeling_arctic/modeling_arctic.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -17,13 +16,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch Arctic model.""" +"""PyTorch Arctic model.""" + import copy import inspect -import time import math -import warnings import re +import warnings from typing import List, Optional, Tuple, Union import torch @@ -31,9 +30,9 @@ import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache +from transformers.integrations.deepspeed import is_deepspeed_available from transformers.modeling_attn_mask_utils import ( _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, @@ -54,12 +53,12 @@ replace_return_docstrings, ) from transformers.utils.import_utils import is_torch_fx_available + from .configuration_arctic import ArcticConfig -from transformers.integrations.deepspeed import is_deepspeed_available -from transformers.utils.versions import require_version if is_deepspeed_available(): - from deepspeed.moe.layer import MoE + from deepspeed.moe.layer import MoE + # Note that below will crash if there is an available deepspeed that does not have ds_linear. try: import deepspeed.linear as ds_linear @@ -70,9 +69,15 @@ if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn.bert_padding import ( # noqa + index_first_axis, + pad_input, + unpad_input, + ) - _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + _flash_supports_window_size = "window_size" in list( + inspect.signature(flash_attn_func).parameters + ) # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. # It means that the function will not be traced through and simply appear as a node in the graph. 
@@ -80,7 +85,9 @@ if not is_torch_greater_or_equal_than_1_13: import torch.fx - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + _prepare_4d_causal_attention_mask = torch.fx.wrap( + _prepare_4d_causal_attention_mask + ) logger = logging.get_logger(__name__) @@ -100,11 +107,15 @@ # if raise_error: # raise ValueError(f"DeepSpeed is required for this feature, {error_msg}") # else: - + # return available_and_valid + def load_balancing_loss_func( - gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=4, attention_mask: Optional[torch.Tensor] = None + gate_logits: torch.Tensor, + num_experts: torch.Tensor = None, + top_k=4, + attention_mask: Optional[torch.Tensor] = None, ) -> float: r""" Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. @@ -128,9 +139,13 @@ def load_balancing_loss_func( if isinstance(gate_logits, tuple): compute_device = gate_logits[0].device - concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + concatenated_gate_logits = torch.cat( + [layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0 + ) - routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + routing_weights = torch.nn.functional.softmax( + concatenated_gate_logits, dim=-1 + ) _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) @@ -144,35 +159,43 @@ def load_balancing_loss_func( router_prob_per_expert = torch.mean(routing_weights, dim=0) else: batch_size, sequence_length = attention_mask.shape - num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + num_hidden_layers = concatenated_gate_logits.shape[0] // ( + batch_size * sequence_length + ) # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask expert_attention_mask = ( attention_mask[None, :, :, None, None] - .expand((num_hidden_layers, batch_size, sequence_length, 2, num_experts)) + .expand( + (num_hidden_layers, batch_size, sequence_length, 2, num_experts) + ) .reshape(-1, 2, num_experts) .to(compute_device) ) # Compute the percentage of tokens routed to each experts - tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( - expert_attention_mask, dim=0 - ) + tokens_per_expert = torch.sum( + expert_mask.float() * expert_attention_mask, dim=0 + ) / torch.sum(expert_attention_mask, dim=0) # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert router_per_expert_attention_mask = ( attention_mask[None, :, :, None] - .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .expand( + (num_hidden_layers, batch_size, sequence_length, num_experts) + ) .reshape(-1, num_experts) .to(compute_device) ) # Compute the average probability of routing to these experts - router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( - router_per_expert_attention_mask, dim=0 - ) + router_prob_per_expert = torch.sum( + routing_weights * router_per_expert_attention_mask, dim=0 + ) / torch.sum(router_per_expert_attention_mask, dim=0) - overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + overall_loss = torch.sum( + tokens_per_expert * router_prob_per_expert.unsqueeze(0) + ) return overall_loss * num_experts @@ -181,7 +204,9 @@ def _get_unpad_data(attention_mask): seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) indices = 
torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() - cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) return ( indices, cu_seqlens, @@ -203,40 +228,57 @@ def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) return self.weight * hidden_states.to(input_dtype) # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Arctic class ArcticRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, dim, max_position_embeddings=2048, base=10000, device=None + ): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base - inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + inv_freq = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) self.register_buffer("inv_freq", inv_freq, persistent=False) # Build here to make `torch.jit.trace` work. self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), ) def _set_cos_sin_cache(self, seq_len, device, dtype): self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer( + "cos_cached", emb.cos().to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin().to(dtype), persistent=False + ) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache( + seq_len=seq_len, device=x.device, dtype=x.dtype + ) return ( self.cos_cached[:seq_len].to(dtype=x.dtype), @@ -289,8 +331,12 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: batch, num_key_value_heads, slen, head_dim = hidden_states.shape if n_rep == 1: return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape( + batch, num_key_value_heads * n_rep, slen, head_dim + ) # Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Arctic @@ -300,7 +346,9 @@ class ArcticAttention(nn.Module): and "Generating Long Sequences 
with Sparse Transformers". """ - def __init__(self, config: ArcticConfig, layer_idx: Optional[int] = None, **kwargs): + def __init__( + self, config: ArcticConfig, layer_idx: Optional[int] = None, **kwargs + ): super().__init__() self.config = config self.layer_idx = layer_idx @@ -320,7 +368,9 @@ def __init__(self, config: ArcticConfig, layer_idx: Optional[int] = None, **kwa self.rope_theta = config.rope_theta self.is_causal = True self.attention_dropout = config.attention_dropout - self.use_deepspeed_implementation = USE_DEEPSPEED_MOE_ARG in kwargs and kwargs[USE_DEEPSPEED_MOE_ARG] + self.use_deepspeed_implementation = ( + USE_DEEPSPEED_MOE_ARG in kwargs and kwargs[USE_DEEPSPEED_MOE_ARG] + ) if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" @@ -331,31 +381,47 @@ def __init__(self, config: ArcticConfig, layer_idx: Optional[int] = None, **kwa deepspeed_lora_config = kwargs.get(DEEPSPEED_LORA_CONFIG) quantization_config = kwargs.get(QUANTIZATION_CONFIG, None) - self.q_proj = get_arctic_linear(self.hidden_size, self.num_heads * self.head_dim, bias=False, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_lora_config=deepspeed_lora_config, - ds_optimized_quantization_config=quantization_config, - ds_optimized_base_weight_sharding=True, - dtype=torch.bfloat16) - self.k_proj = get_arctic_linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_lora_config=deepspeed_lora_config, - ds_optimized_quantization_config=quantization_config, - ds_optimized_base_weight_sharding=True, - dtype=torch.bfloat16) - self.v_proj = get_arctic_linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_lora_config=deepspeed_lora_config, - ds_optimized_quantization_config=quantization_config, - ds_optimized_base_weight_sharding=True, - dtype=torch.bfloat16) - self.o_proj = get_arctic_linear(self.hidden_size, self.hidden_size, bias=False, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_lora_config=deepspeed_lora_config, - ds_optimized_quantization_config=quantization_config, - ds_optimized_base_weight_sharding=True, - dtype=torch.bfloat16) - + self.q_proj = get_arctic_linear( + self.hidden_size, + self.num_heads * self.head_dim, + bias=False, + use_deepspeed_implementation=self.use_deepspeed_implementation, + ds_optimized_lora_config=deepspeed_lora_config, + ds_optimized_quantization_config=quantization_config, + ds_optimized_base_weight_sharding=True, + dtype=torch.bfloat16, + ) + self.k_proj = get_arctic_linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False, + use_deepspeed_implementation=self.use_deepspeed_implementation, + ds_optimized_lora_config=deepspeed_lora_config, + ds_optimized_quantization_config=quantization_config, + ds_optimized_base_weight_sharding=True, + dtype=torch.bfloat16, + ) + self.v_proj = get_arctic_linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=False, + use_deepspeed_implementation=self.use_deepspeed_implementation, + ds_optimized_lora_config=deepspeed_lora_config, + ds_optimized_quantization_config=quantization_config, + ds_optimized_base_weight_sharding=True, + dtype=torch.bfloat16, + ) + self.o_proj = get_arctic_linear( + self.hidden_size, + 
self.hidden_size, + bias=False, + use_deepspeed_implementation=self.use_deepspeed_implementation, + ds_optimized_lora_config=deepspeed_lora_config, + ds_optimized_quantization_config=quantization_config, + ds_optimized_base_weight_sharding=True, + dtype=torch.bfloat16, + ) + self.rotary_emb = ArcticRotaryEmbedding( self.head_dim, max_position_embeddings=self.max_position_embeddings, @@ -363,7 +429,11 @@ def __init__(self, config: ArcticConfig, layer_idx: Optional[int] = None, **kwa ) def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + return ( + tensor.view(bsz, seq_len, self.num_heads, self.head_dim) + .transpose(1, 2) + .contiguous() + ) def forward( self, @@ -374,7 +444,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]] + ]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" @@ -385,9 +457,15 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -397,19 +475,27 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." 
) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx + ) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ) / math.sqrt(self.head_dim) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -426,8 +512,12 @@ def forward( attn_weights = attn_weights + attention_mask # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) - attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) attn_output = torch.matmul(attn_weights, value_states) if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): @@ -460,9 +550,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
- self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = ( + not is_flash_attn_greater_or_equal_2_10() + ) def forward( self, @@ -487,9 +579,15 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: @@ -499,13 +597,17 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx + ) # Because the input can be padded, the absolute sequence length depends on the max position id. rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) use_sliding_windows = ( _flash_supports_window_size @@ -521,7 +623,9 @@ def forward( if past_key_value is not None: # Activate slicing cache only if the config has a value `sliding_windows` attribute - cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + cache_has_contents = ( + past_key_value.get_seq_length(self.layer_idx) > 0 + ) if ( getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window @@ -543,10 +647,18 @@ def forward( if attention_mask is not None: attention_mask = attention_mask[:, slicing_tokens:] - attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + attention_mask = torch.cat( + [ + attention_mask, + torch.ones_like(attention_mask[:, -1:]), + ], + dim=-1, + ) cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -591,7 +703,9 @@ def forward( use_sliding_windows=use_sliding_windows, ) - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = attn_output.reshape( + bsz, q_len, self.hidden_size + ).contiguous() attn_output = self.o_proj(attn_output) if not output_attentions: @@ -639,8 +753,19 @@ def _flash_attention_forward( # Contains at least one padding token in the sequence if attention_mask is not None: batch_size = query_states.shape[0] - query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, 
attention_mask, query_length + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, + key_states, + value_states, + attention_mask, + query_length, ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -671,10 +796,15 @@ def _flash_attention_forward( dropout_p=dropout, softmax_scale=softmax_scale, causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), + window_size=( + self.config.sliding_window, + self.config.sliding_window, + ), ) - attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) else: if not use_sliding_windows: attn_output = flash_attn_func( @@ -693,28 +823,46 @@ def _flash_attention_forward( dropout, softmax_scale=softmax_scale, causal=causal, - window_size=(self.config.sliding_window, self.config.sliding_window), + window_size=( + self.config.sliding_window, + self.config.sliding_window, + ), ) return attn_output - def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape # On the first iteration we need to properly re-create the padding mask # by slicing it on the proper place if kv_seq_len != attention_mask.shape[-1]: attention_mask_num_tokens = attention_mask.shape[-1] - attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + attention_mask = attention_mask[ + :, attention_mask_num_tokens - kv_seq_len : + ] - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask + ) - key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) - value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), + indices_k, + ) if query_length == kv_seq_len: query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + query_layer.reshape( + batch_size * kv_seq_len, num_heads, head_dim + ), + indices_k, ) cu_seqlens_q = cu_seqlens_k max_seqlen_in_batch_q = max_seqlen_in_batch_k @@ -729,7 +877,9 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query else: # The -q_len: slice assumes left padding. 
attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = ( + unpad_input(query_layer, attention_mask) + ) return ( query_layer, @@ -740,14 +890,17 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query (max_seqlen_in_batch_q, max_seqlen_in_batch_k), ) -def get_arctic_linear(input_dim, - output_dim, - bias=False, - use_deepspeed_implementation=False, - ds_optimized_lora_config=None, - ds_optimized_quantization_config=None, - ds_optimized_base_weight_sharding=False, - dtype=torch.bfloat16): + +def get_arctic_linear( + input_dim, + output_dim, + bias=False, + use_deepspeed_implementation=False, + ds_optimized_lora_config=None, + ds_optimized_quantization_config=None, + ds_optimized_base_weight_sharding=False, + dtype=torch.bfloat16, +): """Can return deepspeed optimized linear if available. Args: input_dim, output_dim, bias, dtype: self explanatory (same as from nn.Linear) @@ -758,9 +911,22 @@ def get_arctic_linear(input_dim, """ if is_deepspeed_available(): if ds_optimized_lora_config is not None: - ds_optimized_lora_config: ds_linear.LoRAConfig = copy.deepcopy(ds_optimized_lora_config) - ds_optimized_lora_config.base_weight_sharding = torch.distributed.get_world_size() if ds_optimized_base_weight_sharding else 1 - return ds_linear.OptimizedLinear(input_dim, output_dim, bias, ds_optimized_lora_config, ds_optimized_quantization_config, dtype=dtype) + ds_optimized_lora_config: ds_linear.LoRAConfig = copy.deepcopy( + ds_optimized_lora_config + ) + ds_optimized_lora_config.base_weight_sharding = ( + torch.distributed.get_world_size() + if ds_optimized_base_weight_sharding + else 1 + ) + return ds_linear.OptimizedLinear( + input_dim, + output_dim, + bias, + ds_optimized_lora_config, + ds_optimized_quantization_config, + dtype=dtype, + ) return nn.Linear(input_dim, output_dim, bias=bias, dtype=dtype) @@ -781,7 +947,9 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]] + ]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
logger.warning_once( @@ -803,20 +971,32 @@ def forward( key_states = self.k_proj(hidden_states) value_states = self.v_proj(hidden_states) - query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) - key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) kv_seq_len = key_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx + ) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin, position_ids + ) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -860,44 +1040,63 @@ def forward( class ArcticMLP(nn.Module): - def __init__(self, config: ArcticConfig, - use_deepspeed_implementation=False, - ds_optimized_lora_config=None, - ds_optimized_quantization_config=None, - shard_base_weights_if_doing_lora=False, - is_residual_mlp=False): + def __init__( + self, + config: ArcticConfig, + use_deepspeed_implementation=False, + ds_optimized_lora_config=None, + ds_optimized_quantization_config=None, + shard_base_weights_if_doing_lora=False, + is_residual_mlp=False, + ): """MLP class for Arctic supporting vanilla linear layers as well as some deepspeed optimizations. ds_optimized_lora_config: config of type ds_linear.LoRAConfig that contains lora specific parameter if we want to add lora to this layer. ds_optimized_quantization_config: config of type ds_linear.QuantizationConfig. ds_optimized_base_weight_sharding: bool. If true, the base weight for lora (provided ds_optimized_lora_config is not None) will be sharded across all available gpus in a tensor parallel way. is_residual_mlp: bool. If true, this is MLP inside arctic residual layer which has ffn_dim the same as full intermediate_size. 
- """ + """ super(ArcticMLP, self).__init__() self.hidden_dim = config.hidden_size - self.ffn_dim = config.intermediate_size if not is_residual_mlp else self.hidden_dim - self.w1 = get_arctic_linear(self.hidden_dim, self.ffn_dim, False, - use_deepspeed_implementation=use_deepspeed_implementation, - ds_optimized_lora_config=ds_optimized_lora_config, - ds_optimized_quantization_config=ds_optimized_quantization_config, - ds_optimized_base_weight_sharding=shard_base_weights_if_doing_lora, - dtype=torch.bfloat16) - self.w2 = get_arctic_linear(self.ffn_dim, self.hidden_dim, False, - use_deepspeed_implementation=use_deepspeed_implementation, - ds_optimized_lora_config=ds_optimized_lora_config, - ds_optimized_quantization_config=ds_optimized_quantization_config, - ds_optimized_base_weight_sharding=shard_base_weights_if_doing_lora, - dtype=torch.bfloat16) - self.w3 = get_arctic_linear(self.hidden_dim, self.ffn_dim, False, - use_deepspeed_implementation=use_deepspeed_implementation, - ds_optimized_lora_config=ds_optimized_lora_config, - ds_optimized_quantization_config=ds_optimized_quantization_config, - ds_optimized_base_weight_sharding=shard_base_weights_if_doing_lora, - dtype=torch.bfloat16) + self.ffn_dim = ( + config.intermediate_size if not is_residual_mlp else self.hidden_dim + ) + self.w1 = get_arctic_linear( + self.hidden_dim, + self.ffn_dim, + False, + use_deepspeed_implementation=use_deepspeed_implementation, + ds_optimized_lora_config=ds_optimized_lora_config, + ds_optimized_quantization_config=ds_optimized_quantization_config, + ds_optimized_base_weight_sharding=shard_base_weights_if_doing_lora, + dtype=torch.bfloat16, + ) + self.w2 = get_arctic_linear( + self.ffn_dim, + self.hidden_dim, + False, + use_deepspeed_implementation=use_deepspeed_implementation, + ds_optimized_lora_config=ds_optimized_lora_config, + ds_optimized_quantization_config=ds_optimized_quantization_config, + ds_optimized_base_weight_sharding=shard_base_weights_if_doing_lora, + dtype=torch.bfloat16, + ) + self.w3 = get_arctic_linear( + self.hidden_dim, + self.ffn_dim, + False, + use_deepspeed_implementation=use_deepspeed_implementation, + ds_optimized_lora_config=ds_optimized_lora_config, + ds_optimized_quantization_config=ds_optimized_quantization_config, + ds_optimized_base_weight_sharding=shard_base_weights_if_doing_lora, + dtype=torch.bfloat16, + ) self.act_fn = ACT2FN[config.hidden_act] def forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3( + hidden_states + ) current_hidden_states = self.w2(current_hidden_states) return current_hidden_states @@ -908,54 +1107,71 @@ def __init__(self, config: ArcticConfig, layer_id: int, **kwargs): self.hidden_dim = config.hidden_size self.num_experts = config.num_local_experts - self.layer_id = layer_id + self.layer_id = layer_id self.top_k = config.num_experts_per_tok - self.is_moe_layer = (layer_id+1) % config.moe_layer_frequency == 0 + self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0 - self.use_deepspeed_implementation = USE_DEEPSPEED_MOE_ARG in kwargs and kwargs[USE_DEEPSPEED_MOE_ARG] + self.use_deepspeed_implementation = ( + USE_DEEPSPEED_MOE_ARG in kwargs and kwargs[USE_DEEPSPEED_MOE_ARG] + ) if self.use_deepspeed_implementation and MoE is None: raise ValueError("Deepspeed is not installed") quantization_config = kwargs.get(QUANTIZATION_CONFIG, None) deepspeed_lora = kwargs.get(DEEPSPEED_LORA_CONFIG) - if not 
self.is_moe_layer: # dense, not MoE - self.mlp = ArcticMLP(config, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_quantization_config=quantization_config, - ds_optimized_lora_config=deepspeed_lora, - shard_base_weights_if_doing_lora=True) + if not self.is_moe_layer: # dense, not MoE + self.mlp = ArcticMLP( + config, + use_deepspeed_implementation=self.use_deepspeed_implementation, + ds_optimized_quantization_config=quantization_config, + ds_optimized_lora_config=deepspeed_lora, + shard_base_weights_if_doing_lora=True, + ) else: - if self.use_deepspeed_implementation: # DeepSpeed's MoE - moe_expert_parallel_size = kwargs.get(MOE_EXPERT_PARALLEL_SIZE_ARG, 1) - self.mlp = MoE(self.hidden_dim, - # base weight sharding false for all deepspeed moe calls because it is already sharded - ArcticMLP(config, - use_deepspeed_implementation=True, - ds_optimized_quantization_config=quantization_config, - ds_optimized_lora_config=deepspeed_lora, - shard_base_weights_if_doing_lora=False), - num_experts=config.num_local_experts, - ep_size=moe_expert_parallel_size, - k=config.num_experts_per_tok, - use_residual=False, - capacity_factor=config.moe_train_capacity_factor, - eval_capacity_factor=config.moe_eval_capacity_factor, - enable_expert_tensor_parallelism=config.enable_expert_tensor_parallelism, - min_capacity=config.moe_min_capacity, - drop_tokens=config.moe_token_dropping - ) + if self.use_deepspeed_implementation: # DeepSpeed's MoE + moe_expert_parallel_size = kwargs.get( + MOE_EXPERT_PARALLEL_SIZE_ARG, 1 + ) + self.mlp = MoE( + self.hidden_dim, + # base weight sharding false for all deepspeed moe calls because it is already sharded + ArcticMLP( + config, + use_deepspeed_implementation=True, + ds_optimized_quantization_config=quantization_config, + ds_optimized_lora_config=deepspeed_lora, + shard_base_weights_if_doing_lora=False, + ), + num_experts=config.num_local_experts, + ep_size=moe_expert_parallel_size, + k=config.num_experts_per_tok, + use_residual=False, + capacity_factor=config.moe_train_capacity_factor, + eval_capacity_factor=config.moe_eval_capacity_factor, + enable_expert_tensor_parallelism=config.enable_expert_tensor_parallelism, + min_capacity=config.moe_min_capacity, + drop_tokens=config.moe_token_dropping, + ) else: # "local" MoE implementation - self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = nn.ModuleList([ArcticMLP(config, - use_deepspeed_implementation=self.use_deepspeed_implementation, - ds_optimized_quantization_config=quantization_config, - ds_optimized_lora_config=deepspeed_lora, - shard_base_weights_if_doing_lora=True) for i in range(self.num_experts)]) + self.gate = nn.Linear( + self.hidden_dim, self.num_experts, bias=False + ) + self.experts = nn.ModuleList( + [ + ArcticMLP( + config, + use_deepspeed_implementation=self.use_deepspeed_implementation, + ds_optimized_quantization_config=quantization_config, + ds_optimized_lora_config=deepspeed_lora, + shard_base_weights_if_doing_lora=True, + ) + for i in range(self.num_experts) + ] + ) # if torch.distributed.get_rank() == 0: # deepspeed.runtime.utils.see_memory_usage("", force=True) - # Similar in behavior to transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock.forward but more efficient. 
def _moe_foreward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape @@ -964,26 +1180,33 @@ def _moe_foreward(self, hidden_states: torch.Tensor) -> torch.Tensor: router_logits = self.gate(hidden_states) routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) - routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights, selected_experts = torch.topk( + routing_weights, self.top_k, dim=-1 + ) if self.top_k > 1: routing_weights /= routing_weights.sum(dim=-1, keepdim=True) # we cast back to the input dtype final_hidden_states = torch.zeros( - (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, ) # Matching between experts, tokens, and their top-k rank. For every i, # expert_idx[i] is the rank topk_idx[i] expert for token_idx[i]. expert_idx, token_idx, topk_idx = torch.where( - selected_experts == torch.arange( + selected_experts + == torch.arange( self.num_experts, device=selected_experts.device, ).view((self.num_experts, 1, 1)) ) # Split into one chunk per expert. - bincount = torch.bincount(expert_idx, minlength=self.num_experts).tolist() + bincount = torch.bincount( + expert_idx, minlength=self.num_experts + ).tolist() token_idx = token_idx.split(bincount) topk_idx = topk_idx.split(bincount) @@ -999,15 +1222,26 @@ def _moe_foreward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + current_state = hidden_states[None, top_x_list].reshape( + -1, hidden_dim + ) + current_hidden_states = ( + expert_layer(current_state) + * routing_weights[top_x_list, idx_list, None] + ) # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. 
- final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states.index_add_( + 0, top_x, current_hidden_states.to(hidden_states.dtype) + ) # torch.distributed.barrier() - final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) - return final_hidden_states, load_balancing_loss_func((router_logits, ), self.num_experts, self.top_k) # ZY: let's directly output the loss to align what we have in ds + final_hidden_states = final_hidden_states.reshape( + batch_size, sequence_length, hidden_dim + ) + return final_hidden_states, load_balancing_loss_func( + (router_logits,), self.num_experts, self.top_k + ) # ZY: let's directly output the loss to align what we have in ds def forward(self, hidden_states: torch.Tensor): if self.is_moe_layer: @@ -1018,7 +1252,9 @@ def forward(self, hidden_states: torch.Tensor): else: return self._moe_foreward(hidden_states) else: - return self.mlp(hidden_states), torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) + return self.mlp(hidden_states), torch.tensor( + 0.0, device=hidden_states.device, dtype=hidden_states.dtype + ) class ArcticDecoderLayer(nn.Module): @@ -1026,23 +1262,37 @@ def __init__(self, config: ArcticConfig, layer_idx: int, **kwargs): super().__init__() self.layer_idx = layer_idx self.hidden_size = config.hidden_size - self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx, **kwargs) + self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation]( + config, layer_idx, **kwargs + ) self.block_sparse_moe = ArcticMoE(config, layer_id=layer_idx, **kwargs) - self.input_layernorm = ArcticRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = ArcticRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.use_deepspeed_implementation = USE_DEEPSPEED_MOE_ARG in kwargs and kwargs[USE_DEEPSPEED_MOE_ARG] + self.input_layernorm = ArcticRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = ArcticRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.use_deepspeed_implementation = ( + USE_DEEPSPEED_MOE_ARG in kwargs and kwargs[USE_DEEPSPEED_MOE_ARG] + ) - self.parallel_attn_mlp_res = config.parallel_attn_mlp_res and self.block_sparse_moe.is_moe_layer # add residual only when it is moe layer + self.parallel_attn_mlp_res = ( + config.parallel_attn_mlp_res and self.block_sparse_moe.is_moe_layer + ) # add residual only when it is moe layer deepspeed_quantization = kwargs.get(DEEPSPEED_QUANTIZATION_CONFIG) deepspeed_lora = kwargs.get(DEEPSPEED_LORA_CONFIG) if self.parallel_attn_mlp_res: - self.residual_layernorm = ArcticRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.residual_mlp = ArcticMLP(config, - use_deepspeed_implementation=self.use_deepspeed_implementation, - is_residual_mlp=True, - ds_optimized_quantization_config=deepspeed_quantization, - ds_optimized_lora_config=deepspeed_lora, - shard_base_weights_if_doing_lora=True) # for the residual layer. always shard the base weight if doing deepspeed lora. + self.residual_layernorm = ArcticRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.residual_mlp = ArcticMLP( + config, + use_deepspeed_implementation=self.use_deepspeed_implementation, + is_residual_mlp=True, + ds_optimized_quantization_config=deepspeed_quantization, + ds_optimized_lora_config=deepspeed_lora, + shard_base_weights_if_doing_lora=True, + ) # for the residual layer. 
always shard the base weight if doing deepspeed lora. def forward( self, @@ -1053,7 +1303,9 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, **kwargs, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" @@ -1088,7 +1340,7 @@ def forward( hidden_states = residual_input + hidden_states residual_attn = hidden_states - + if self.parallel_attn_mlp_res: # Note the architecture here is that the MOE layers reads the **pre-attention** input while there is a "normal" transformer residual part. # This is to achieve better parallelization. @@ -1099,7 +1351,9 @@ def forward( hidden_states = self.residual_mlp(hidden_states) residual_residual = residual_attn + hidden_states # parallel mlp moe part - hidden_states = self.post_attention_layernorm(residual_input) # parallel attn mlp has the same input + hidden_states = self.post_attention_layernorm( + residual_input + ) # parallel attn mlp has the same input hidden_states, gate_loss = self.block_sparse_moe(hidden_states) hidden_states = residual_residual + hidden_states else: @@ -1173,6 +1427,7 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + MIXTRAL_INPUTS_DOCSTRING = r""" Args: input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): @@ -1243,9 +1498,14 @@ def __init__(self, config: ArcticConfig, **kwargs): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) self.layers = nn.ModuleList( - [ArcticDecoderLayer(config, layer_idx, **kwargs) for layer_idx in range(config.num_hidden_layers)] + [ + ArcticDecoderLayer(config, layer_idx, **kwargs) + for layer_idx in range(config.num_hidden_layers) + ] ) self._attn_implementation = config._attn_implementation self.norm = ArcticRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -1274,23 +1534,39 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MoeModelOutputWithPast]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict + ) # retrieve input_ids and inputs_embeds if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + raise ValueError( + "You cannot specify both decoder_input_ids and 
decoder_inputs_embeds at the same time" + ) elif input_ids is not None: batch_size, seq_length = input_ids.shape elif inputs_embeds is not None: batch_size, seq_length, _ = inputs_embeds.shape else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + raise ValueError( + "You have to specify either decoder_input_ids or decoder_inputs_embeds" + ) past_key_values_length = 0 @@ -1304,13 +1580,24 @@ def forward( if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + past_key_values = DynamicCache.from_legacy_cache( + past_key_values + ) + past_key_values_length = past_key_values.get_usable_length( + seq_length + ) if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device + device = ( + input_ids.device + if input_ids is not None + else inputs_embeds.device + ) position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, ) position_ids = position_ids.unsqueeze(0).view(-1, seq_length) else: @@ -1319,7 +1606,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: + if ( + attention_mask is not None + and self._attn_implementation == "flash_attention_2" + and use_cache + ): is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( @@ -1330,7 +1621,11 @@ def forward( if self._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) elif self._attn_implementation == "sdpa" and not output_attentions: # output_attentions=True can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. 
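The hunk above only reflows the cache and position handling in `ArcticModel.forward`; the control flow is unchanged. As a quick reference for how the reformatted `position_ids` branch behaves during incremental decoding, here is a minimal standalone sketch (the example lengths are assumed for illustration and are not taken from the patch):

```python
import torch

# Assumed example values: 10 tokens already sit in the KV cache, 1 new token this step.
past_key_values_length = 10
seq_length = 1
device = torch.device("cpu")

# Same arithmetic as the reformatted block in the diff above.
position_ids = torch.arange(
    past_key_values_length,
    seq_length + past_key_values_length,
    dtype=torch.long,
    device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
print(position_ids)  # tensor([[10]]) -> the new token is assigned absolute position 10
```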
@@ -1385,13 +1680,22 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - if hasattr(layer_outputs[2 if output_attentions else 1], 'to_legacy_cache'): - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + if hasattr( + layer_outputs[2 if output_attentions else 1], + "to_legacy_cache", + ): + next_decoder_cache = layer_outputs[ + 2 if output_attentions else 1 + ] else: if next_decoder_cache is None: - next_decoder_cache = [layer_outputs[2 if output_attentions else 1]] + next_decoder_cache = [ + layer_outputs[2 if output_attentions else 1] + ] else: - next_decoder_cache.append(layer_outputs[2 if output_attentions else 1]) + next_decoder_cache.append( + layer_outputs[2 if output_attentions else 1] + ) if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1405,13 +1709,24 @@ def forward( next_cache = None if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache and hasattr(next_decoder_cache, 'to_legacy_cache') else next_decoder_cache - torch.cuda.empty_cache() + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + and hasattr(next_decoder_cache, "to_legacy_cache") + else next_decoder_cache + ) + torch.cuda.empty_cache() if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_router_losses] + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + all_router_losses, + ] if v is not None ) return MoeModelOutputWithPast( @@ -1422,24 +1737,33 @@ def forward( router_logits=all_router_losses, ) + class ArcticForCausalLM(ArcticPreTrainedModel): # TODO(jeffra): update _keys_to_ignore_on_load_unexpected with expert keys not relevant for this rank - _keys_to_ignore_on_load_unexpected = [r"model\.layers\.\d+\.block_sparse_moe\.experts\.\d+\.w\d+\.weight" - r"model\.layers\.\d+\.block_sparse_moe\.gate\.weight"] - _keys_to_ignore_on_load_missing = [r"model\.layers\.\d+\.block_sparse_moe\.mlp\.deepspeed_moe\.experts\.deepspeed_experts\.\d+\.w\d+\.weight", - r"model\.layers\.\d+\.block_sparse_moe\.mlp\.deepspeed_moe\.gate\.wg\.weight"] - _tied_weights_keys = []#["lm_head.weight"] + _keys_to_ignore_on_load_unexpected = [ + r"model\.layers\.\d+\.block_sparse_moe\.experts\.\d+\.w\d+\.weight" + r"model\.layers\.\d+\.block_sparse_moe\.gate\.weight" + ] + _keys_to_ignore_on_load_missing = [ + r"model\.layers\.\d+\.block_sparse_moe\.mlp\.deepspeed_moe\.experts\.deepspeed_experts\.\d+\.w\d+\.weight", + r"model\.layers\.\d+\.block_sparse_moe\.mlp\.deepspeed_moe\.gate\.wg\.weight", + ] + _tied_weights_keys = [] # ["lm_head.weight"] def __init__(self, config, **kwargs): super().__init__(config) self.model = ArcticModel(config, **kwargs) self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.lm_head = nn.Linear( + config.hidden_size, config.vocab_size, bias=False + ) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_local_experts self.num_experts_per_tok = config.num_experts_per_tok self.use_deepspeed_moe = kwargs.get(USE_DEEPSPEED_MOE_ARG, False) - self.moe_expert_parallel_size = kwargs.get(MOE_EXPERT_PARALLEL_SIZE_ARG, 1) + self.moe_expert_parallel_size = kwargs.get( + MOE_EXPERT_PARALLEL_SIZE_ARG, 1 + ) self.is_deepspeed_lora = kwargs.get(DEEPSPEED_LORA_CONFIG) is not None self.gradient_checkpointing = True # self.shard_base_weights_if_doing_lora = kwargs.get("shard_base_weights_if_doing_lora", False) @@ -1464,10 +1788,9 @@ def 
set_decoder(self, decoder): def get_decoder(self): return self.model - def _expert_number_from_param_name(self, param_name): # example param_name: model.layers.1.block_sparse_moe.experts.10.w1.weight - pattern = r'experts\.(\d+)\.' + pattern = r"experts\.(\d+)\." m = re.search(pattern, param_name) if m: return int(m[1]) @@ -1481,79 +1804,145 @@ def state_dict(self, *args, **kwargs): return state_dict # when trying to construct the deepspeed checkpoint we don't want to gather everything - if not getattr(self, '_gather_expert_params', False): + if not getattr(self, "_gather_expert_params", False): return state_dict - rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 - world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 + rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() + else 0 + ) + world_size = ( + torch.distributed.get_world_size() + if torch.distributed.is_initialized() + else 1 + ) # non-lora experts pattern = r"model\.layers\.\d+\.block_sparse_moe\.mlp\.deepspeed_moe\.experts\.deepspeed_experts\.\d+\.w\d+\.weight" expert_params = [s for s in state_dict.keys() if re.search(pattern, s)] for param_name in expert_params: - param_tensor = state_dict[param_name].to('cuda') + param_tensor = state_dict[param_name].to("cuda") output = [torch.zeros_like(param_tensor) for _ in range(world_size)] - torch.distributed.gather(param_tensor, gather_list=output if rank == 0 else None, dst=0, group=None) + torch.distributed.gather( + param_tensor, + gather_list=output if rank == 0 else None, + dst=0, + group=None, + ) # rename from local rank to global rank for gather_rank, gather_param in enumerate(output): - experts_per_rank = self.num_experts // self.moe_expert_parallel_size - new_expert_number = gather_rank * experts_per_rank + self._expert_number_from_param_name(param_name) - new_param_name = re.sub(r'(experts\.)(\d+)(\.)', rf'\g<1>{new_expert_number}\3', param_name) + experts_per_rank = ( + self.num_experts // self.moe_expert_parallel_size + ) + new_expert_number = ( + gather_rank * experts_per_rank + + self._expert_number_from_param_name(param_name) + ) + new_param_name = re.sub( + r"(experts\.)(\d+)(\.)", + rf"\g<1>{new_expert_number}\3", + param_name, + ) state_dict[new_param_name] = gather_param if rank == 0: - print(f"adding to state_dict and renaming: {param_name} -> {new_param_name}") - - # Handle custom LoRA implementation + print( + f"adding to state_dict and renaming: {param_name} -> {new_param_name}" + ) + + # Handle custom LoRA implementation # TODO(rajhans): the part below is untested and shows up when doing lora training. Should not affect inference. if self.is_deepspeed_lora: - for param_name in list(state_dict.keys()): # Use list to avoid RuntimeError due to changing size during iteration - if param_name.endswith("base_weight"): - base_weight = state_dict[param_name].to('cuda') - - # If the base weight is sharded, gather weights from multiple ranks and concatenate - # except if the weights are from deespeed_moe which is not sharded (due to EP). 
- if self.shard_base_weights_if_doing_lora and 'deepspeed_moe.experts.deepspeed_experts' not in param_name: - gathered_weights = [torch.zeros_like(base_weight, - device=base_weight.device, dtype=base_weight.dtype) for _ in range(world_size)] - torch.distributed.gather(base_weight, gather_list=gathered_weights if rank == 0 else None, dst=0, group=None) + for param_name in list( + state_dict.keys() + ): # Use list to avoid RuntimeError due to changing size during iteration + if param_name.endswith("base_weight"): + base_weight = state_dict[param_name].to("cuda") + + # If the base weight is sharded, gather weights from multiple ranks and concatenate + # except if the weights are from deespeed_moe which is not sharded (due to EP). + if ( + self.shard_base_weights_if_doing_lora + and "deepspeed_moe.experts.deepspeed_experts" + not in param_name + ): + gathered_weights = [ + torch.zeros_like( + base_weight, + device=base_weight.device, + dtype=base_weight.dtype, + ) + for _ in range(world_size) + ] + torch.distributed.gather( + base_weight, + gather_list=gathered_weights if rank == 0 else None, + dst=0, + group=None, + ) base_weight = torch.cat(gathered_weights, dim=1) - - ## The part below is useful if we want to output HF transformer path weights, but commenting it for now - # Merge the LoRA weights into the base weights - # lora_weight_1 = state_dict.get(param_name.replace("base_weight", "lora_weight_1.weight")) - # lora_weight_2 = state_dict.get(param_name.replace("base_weight", "lora_weight_2.weight")) + ## The part below is useful if we want to output HF transformer path weights, but commenting it for now + # Merge the LoRA weights into the base weights + # lora_weight_1 = state_dict.get(param_name.replace("base_weight", "lora_weight_1.weight")) + # lora_weight_2 = state_dict.get(param_name.replace("base_weight", "lora_weight_2.weight")) # if lora_weight_1 is not None and lora_weight_2 is not None: # lora_weights = torch.matmul(lora_weight_2, lora_weight_1) # base_weight += lora_weights # else: - # raise ValueError + # raise ValueError - # # Rename the base weight to weight - # new_param_name = param_name.replace("base_weight", "weight") - # state_dict[new_param_name] = base_weight - - # Remove the base weight from the state dict - # del state_dict[param_name] - return state_dict + # # Rename the base weight to weight + # new_param_name = param_name.replace("base_weight", "weight") + # state_dict[new_param_name] = base_weight + # Remove the base weight from the state dict + # del state_dict[param_name] + return state_dict - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, + ): if not self.use_deepspeed_moe: return super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, ) - world_size = torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 - #TODO(jeffra): currently assumes fine-tuning only on one node, fix for world_size != ep size + world_size = ( + torch.distributed.get_world_size() + if torch.distributed.is_initialized() + else 1 + ) + # TODO(jeffra): currently assumes fine-tuning only on one node, fix for world_size != ep size if self.moe_expert_parallel_size > 1: - assert 
self.moe_expert_parallel_size == world_size, \ - f"currently only support expert parallel size equal to world size but {self.moe_expert_parallel_size=} and {world_size=}" - - rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + assert ( + self.moe_expert_parallel_size == world_size + ), f"currently only support expert parallel size equal to world size but {self.moe_expert_parallel_size=} and {world_size=}" + + rank = ( + torch.distributed.get_rank() + if torch.distributed.is_initialized() + else 0 + ) num_local_experts = self.num_experts // self.moe_expert_parallel_size - local_expert_range = range(num_local_experts * rank, num_local_experts * rank + num_local_experts) + local_expert_range = range( + num_local_experts * rank, + num_local_experts * rank + num_local_experts, + ) # no deepspeed # model.layers.1.block_sparse_moe.experts.10.w1.weight @@ -1562,7 +1951,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss # model.layers.1.block_sparse_moe.mlp.deepspeed_moe.gate.wg.weight # model.layers.1.block_sparse_moe.mlp.deepspeed_moe.experts.deepspeed_experts.10.w1.weight - gate_pattern = r'model\.layers\.\d+\.block_sparse_moe\.gate\.weight' + gate_pattern = r"model\.layers\.\d+\.block_sparse_moe\.gate\.weight" expert_params_to_keep = [] expert_params_to_remove = [] @@ -1579,30 +1968,45 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss # drop all experts in the state_dict that we don't need locally for param_name in expert_params_to_remove: - print(f'{rank=} dropping {param_name}') + print(f"{rank=} dropping {param_name}") del state_dict[param_name] # rename remaining experts to align with the local config for param_name in expert_params_to_keep: # adjust expert number wrt expert parallelism - new_expert_number = self._expert_number_from_param_name(param_name) % num_local_experts - new_param_name = re.sub(r'(experts\.)(\d+)(\.)', rf'\g<1>{new_expert_number}\3', param_name) + new_expert_number = ( + self._expert_number_from_param_name(param_name) + % num_local_experts + ) + new_param_name = re.sub( + r"(experts\.)(\d+)(\.)", + rf"\g<1>{new_expert_number}\3", + param_name, + ) # use deepspeed moe param path - split_param_name = new_param_name.split('.') - idx = split_param_name.index('experts') - ds_moe_path = "mlp.deepspeed_moe.experts.deepspeed_experts".split('.') - new_param_name = split_param_name[0:idx] + ds_moe_path + split_param_name[idx+1:] + split_param_name = new_param_name.split(".") + idx = split_param_name.index("experts") + ds_moe_path = "mlp.deepspeed_moe.experts.deepspeed_experts".split( + "." 
+ ) + new_param_name = ( + split_param_name[0:idx] + + ds_moe_path + + split_param_name[idx + 1 :] + ) new_param_name = ".".join(new_param_name) - print(f'Deepspeed {rank=}, renaming {param_name} -> {new_param_name}') + print( + f"Deepspeed {rank=}, renaming {param_name} -> {new_param_name}" + ) state_dict[new_param_name] = state_dict.pop(param_name) # rename gate params - ds_suffix = "mlp.deepspeed_moe.gate.wg.weight".split('.') + ds_suffix = "mlp.deepspeed_moe.gate.wg.weight".split(".") for param_name in gate_params: - new_param_name = '.'.join(param_name.split('.')[:4] + ds_suffix) - print(f'Gating: {rank=}, renaming {param_name} -> {new_param_name}') + new_param_name = ".".join(param_name.split(".")[:4] + ds_suffix) + print(f"Gating: {rank=}, renaming {param_name} -> {new_param_name}") state_dict[new_param_name] = state_dict.pop(param_name) # If deepspeed lora is enabled, then we need to rename weight to base_weight. @@ -1613,7 +2017,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss if not param_name.endswith("base_weight"): continue - incoming_param_name = param_name.replace("base_weight", "weight") + incoming_param_name = param_name.replace( + "base_weight", "weight" + ) if incoming_param_name not in state_dict: continue @@ -1621,21 +2027,37 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, miss shape_local = local_state_dict[param_name].shape shape_incoming = incoming_param.shape - if 'deepspeed_moe' in incoming_param_name: - assert shape_local == shape_incoming, "deepspeed moe weights are never sharded" + if "deepspeed_moe" in incoming_param_name: + assert ( + shape_local == shape_incoming + ), "deepspeed moe weights are never sharded" else: - assert shape_incoming[1] == shape_local[1] * world_size, "weights should be sharded equally across world size" - incoming_param = incoming_param[:, rank*shape_local[1]: (rank+1)*shape_local[1]] - print(f'Deepspeed lora: {rank=}, renaming {incoming_param_name} -> {param_name}') + assert ( + shape_incoming[1] == shape_local[1] * world_size + ), "weights should be sharded equally across world size" + incoming_param = incoming_param[ + :, rank * shape_local[1] : (rank + 1) * shape_local[1] + ] + print( + f"Deepspeed lora: {rank=}, renaming {incoming_param_name} -> {param_name}" + ) state_dict[param_name] = incoming_param del state_dict[incoming_param_name] return super()._load_from_state_dict( - state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, ) @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings( + output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC + ) # Ignore copy def forward( self, @@ -1670,12 +2092,22 @@ def forward( "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) outputs = self.model( @@ -1726,7 +2158,12 @@ def forward( ) def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, ): # Omit tokens covered by past_key_values if past_key_values is not None: @@ -1742,8 +2179,13 @@ def prepare_inputs_for_generation( # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) - if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[ + :, -(attention_mask.shape[1] - past_length) : + ] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: @@ -1787,7 +2229,10 @@ def _reorder_cache(past_key_values, beam_idx): reordered_past = () for layer_past in past_key_values: reordered_past += ( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), ) return reordered_past @@ -1842,7 +2287,11 @@ def forward( config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + return_dict = ( + return_dict + if return_dict is not None + else self.config.use_return_dict + ) transformer_outputs = self.model( input_ids, @@ -1864,19 +2313,28 @@ def forward( batch_size = inputs_embeds.shape[0] if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." 
+ ) if self.config.pad_token_id is None: sequence_lengths = -1 else: if input_ids is not None: # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility - sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id) + .int() + .argmax(-1) + - 1 + ) sequence_lengths = sequence_lengths % input_ids.shape[-1] sequence_lengths = sequence_lengths.to(logits.device) else: sequence_lengths = -1 - pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] loss = None if labels is not None: @@ -1884,7 +2342,9 @@ def forward( if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" @@ -1897,7 +2357,9 @@ def forward( loss = loss_fct(pooled_logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = CrossEntropyLoss() - loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) elif self.config.problem_type == "multi_label_classification": loss_fct = BCEWithLogitsLoss() loss = loss_fct(pooled_logits, labels) @@ -1911,4 +2373,4 @@ def forward( past_key_values=transformer_outputs.past_key_values, hidden_states=transformer_outputs.hidden_states, attentions=transformer_outputs.attentions, - ) \ No newline at end of file + ) diff --git a/moe_infinity/models/modeling_arctic/tokenization_arctic.py b/moe_infinity/models/modeling_arctic/tokenization_arctic.py index 61b9ea4..89bd70a 100644 --- a/moe_infinity/models/modeling_arctic/tokenization_arctic.py +++ b/moe_infinity/models/modeling_arctic/tokenization_arctic.py @@ -6,7 +6,6 @@ class ArcticTokenizer(LlamaTokenizer): - def __init__( self, vocab_file, @@ -54,4 +53,4 @@ def default_chat_template(self): "{% if add_generation_prompt %}" "{{ '<|im_start|>assistant\n' }}" "{% endif %}" - ) \ No newline at end of file + ) diff --git a/moe_infinity/models/modeling_deepseek/__init__.py b/moe_infinity/models/modeling_deepseek/__init__.py new file mode 100644 index 0000000..d8e8499 --- /dev/null +++ b/moe_infinity/models/modeling_deepseek/__init__.py @@ -0,0 +1,3 @@ +from .configuration_deepseek import DeepseekV2Config +from .modeling_deepseek import DeepseekV2ForCausalLM, DeepseekV2MLP, MoEGate, DeepseekV2MoE +from .tokenization_deepseek_fast import DeepseekTokenizerFast \ No newline at end of file diff --git a/moe_infinity/models/modeling_deepseek/configuration_deepseek.py b/moe_infinity/models/modeling_deepseek/configuration_deepseek.py new file mode 100644 index 0000000..82e0f5d --- /dev/null +++ b/moe_infinity/models/modeling_deepseek/configuration_deepseek.py @@ -0,0 +1,206 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV2Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV2Model`]. 
It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V2. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 102400): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV2Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. + seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. 
+ max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + pretraining_tp (`int`, *optional*, defaults to 1): + Experimental feature. Tensor parallelism rank used during pretraining. Please refer to [this + document](https://huggingface.co/docs/transformers/parallelism) to understand more about it. This value is + necessary to ensure exact reproducibility of the pretraining results. Please refer to [this + issue](https://github.com/pytorch/pytorch/issues/76232). + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ + ```python + >>> from transformers import DeepseekV2Model, DeepseekV2Config + + >>> # Initializing a Deepseek-V2 style configuration + >>> configuration = DeepseekV2Config() + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v2" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=102400, + hidden_size=4096, + intermediate_size=11008, + moe_intermediate_size = 1407, + num_hidden_layers=30, + num_attention_heads=32, + num_key_value_heads=32, + n_shared_experts = None, + n_routed_experts = None, + ep_size = 1, + routed_scaling_factor = 1.0, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'gready', + n_group = None, + topk_group = None, + num_experts_per_tok = None, + moe_layer_freq = 1, + first_k_dense_replace = 0, + norm_topk_prob = False, + scoring_func = 'softmax', + aux_loss_alpha = 0.001, + seq_aux = True, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=100000, + eos_token_id=100001, + pretraining_tp=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + self.aux_loss_alpha = aux_loss_alpha + self.seq_aux = seq_aux + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.pretraining_tp = pretraining_tp + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/moe_infinity/models/modeling_deepseek/modeling_deepseek.py b/moe_infinity/models/modeling_deepseek/modeling_deepseek.py new file mode 100644 index 0000000..847a458 --- /dev/null +++ b/moe_infinity/models/modeling_deepseek/modeling_deepseek.py @@ -0,0 +1,1922 @@ +# coding=utf-8 +# Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. 
+# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch DeepSeek model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import ( + ALL_LAYERNORM_LAYERS, + is_torch_greater_or_equal_than_1_13, +) +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from .configuration_deepseek import DeepseekV2Config +import torch.distributed as dist +import numpy as np + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. 
+if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "DeepseekV2Config" + + +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad( + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0) + ) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +class DeepseekV2RMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + DeepseekV2RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + +ALL_LAYERNORM_LAYERS.append(DeepseekV2RMSNorm) + + +class DeepseekV2RotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / ( + self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, + device=self.inv_freq.device, + dtype=torch.get_default_dtype(), + ) + self.max_seq_len_cached = None + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq.to(t.device)) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2LinearScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + t = t / self.scaling_factor + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 +class DeepseekV2DynamicNTKScalingRotaryEmbedding(DeepseekV2RotaryEmbedding): + """DeepseekV2RotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + ): + self.scaling_factor = scaling_factor + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + + if seq_len > self.max_position_embeddings: + base = self.base * ( + (self.scaling_factor * seq_len / self.max_position_embeddings) + - (self.scaling_factor - 1) + ) ** (self.dim / (self.dim - 2)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange( + self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype + ) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + +# Inverse dim formula to find dim based on number of rotations +def yarn_find_correction_dim( + num_rotations, dim, base=10000, max_position_embeddings=2048 +): + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def yarn_find_correction_range( + low_rot, high_rot, dim, base=10000, max_position_embeddings=2048 +): + low = math.floor( + yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def yarn_get_mscale(scale=1, mscale=1): + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +def yarn_linear_ramp_mask(min, max, dim): + if min == max: + max += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): + + def __init__( + self, + dim, + max_position_embeddings=2048, + base=10000, + device=None, + scaling_factor=1.0, + original_max_position_embeddings=4096, + beta_fast=32, + 
beta_slow=1, + mscale=1, + mscale_all_dim=0, + ): + self.scaling_factor = scaling_factor + self.original_max_position_embeddings = original_max_position_embeddings + self.beta_fast = beta_fast + self.beta_slow = beta_slow + self.mscale = mscale + self.mscale_all_dim = mscale_all_dim + super().__init__(dim, max_position_embeddings, base, device) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + dim = self.dim + + freq_extra = 1.0 / ( + self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + freq_inter = 1.0 / ( + self.scaling_factor + * self.base + ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ) + + low, high = yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + dim, + self.base, + self.original_max_position_embeddings, + ) + inv_freq_mask = 1.0 - yarn_linear_ramp_mask(low, high, dim // 2).to( + device=device, dtype=torch.float32 + ) + inv_freq = freq_inter * (1 - inv_freq_mask) + freq_extra * inv_freq_mask + self.register_buffer("inv_freq", inv_freq, persistent=False) + + t = torch.arange(seq_len, device=device, dtype=torch.float32) + + freqs = torch.outer(t, inv_freq) + + _mscale = float( + yarn_get_mscale(self.scaling_factor, self.mscale) + / yarn_get_mscale(self.scaling_factor, self.mscale_all_dim) + ) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer( + "cos_cached", (emb.cos() * _mscale).to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", (emb.sin() * _mscale).to(dtype), persistent=False + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. 
+ """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + b, h, s, d = q.shape + q = q.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + b, h, s, d = k.shape + k = k.view(b, h, s, d // 2, 2).transpose(4, 3).reshape(b, h, s, d) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class DeepseekV2MLP(nn.Module): + def __init__(self, config, hidden_size=None, intermediate_size=None): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.intermediate_size = ( + config.intermediate_size if intermediate_size is None else intermediate_size + ) + + self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) + self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, x): + down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + return down_proj + + +class MoEGate(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.top_k = config.num_experts_per_tok + self.n_routed_experts = config.n_routed_experts + self.routed_scaling_factor = config.routed_scaling_factor + self.scoring_func = config.scoring_func + self.alpha = config.aux_loss_alpha + self.seq_aux = config.seq_aux + self.topk_method = config.topk_method + self.n_group = config.n_group + self.topk_group = config.topk_group + + # topk selection algorithm + self.norm_topk_prob = config.norm_topk_prob + self.gating_dim = config.hidden_size + self.weight = nn.Parameter( + torch.empty((self.n_routed_experts, self.gating_dim)) + ) + self.reset_parameters() + + def reset_parameters(self) -> None: + import torch.nn.init as init + + init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, hidden_states): + bsz, seq_len, h = hidden_states.shape + ### compute gating score + hidden_states = hidden_states.view(-1, h) + logits = F.linear( + hidden_states.type(torch.float32), self.weight.type(torch.float32), None + ) + if self.scoring_func == "softmax": + scores = logits.softmax(dim=-1, dtype=torch.float32) + else: + raise NotImplementedError( + f"insupportable scoring function for MoE gating: {self.scoring_func}" + ) + + ### select top-k experts + if self.topk_method == "greedy": + topk_weight, topk_idx = torch.topk( + scores, k=self.top_k, dim=-1, sorted=False + ) + elif self.topk_method == "group_limited_greedy": + group_scores = ( + scores.view(bsz * seq_len, self.n_group, -1).max(dim=-1).values + ) # [n, n_group] + group_idx = torch.topk( + group_scores, k=self.topk_group, dim=-1, sorted=False + )[ + 1 + ] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = ( + group_mask.unsqueeze(-1) + .expand( + bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + ) + .reshape(bsz * seq_len, -1) + ) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weight, topk_idx = torch.topk( + tmp_scores, k=self.top_k, dim=-1, sorted=False + ) + + ### norm gate to sum 1 + if self.top_k > 1 and self.norm_topk_prob: + denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20 + topk_weight = topk_weight / denominator + else: + topk_weight = topk_weight * 
self.routed_scaling_factor + ### expert-level computation auxiliary loss + if self.training and self.alpha > 0.0: + scores_for_aux = scores + aux_topk = self.top_k + # always compute aux loss based on the naive greedy topk method + topk_idx_for_aux_loss = topk_idx.view(bsz, -1) + if self.seq_aux: + scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1) + ce = torch.zeros( + bsz, self.n_routed_experts, device=hidden_states.device + ) + ce.scatter_add_( + 1, + topk_idx_for_aux_loss, + torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + ).div_(seq_len * aux_topk / self.n_routed_experts) + aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( + dim=1 + ).mean() * self.alpha + else: + mask_ce = F.one_hot( + topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts + ) + ce = mask_ce.float().mean(0) + Pi = scores_for_aux.mean(0) + fi = ce * self.n_routed_experts + aux_loss = (Pi * fi).sum() * self.alpha + else: + aux_loss = None + return topk_idx, topk_weight, aux_loss + + +class AddAuxiliaryLoss(torch.autograd.Function): + """ + The trick function of adding auxiliary (aux) loss, + which includes the gradient of the aux loss during backpropagation. + """ + + @staticmethod + def forward(ctx, x, loss): + assert loss.numel() == 1 + ctx.dtype = loss.dtype + ctx.required_aux_loss = loss.requires_grad + return x + + @staticmethod + def backward(ctx, grad_output): + grad_loss = None + if ctx.required_aux_loss: + grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + return grad_output, grad_loss + + +class DeepseekV2MoE(nn.Module): + """ + A mixed expert module containing shared experts. + """ + + def __init__(self, config): + super().__init__() + self.config = config + self.num_experts_per_tok = config.num_experts_per_tok + + if hasattr(config, "ep_size") and config.ep_size > 1: + assert config.ep_size == dist.get_world_size() + self.ep_size = config.ep_size + self.experts_per_rank = config.n_routed_experts // config.ep_size + self.ep_rank = dist.get_rank() + self.experts = nn.ModuleList( + [ + ( + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + if i >= self.ep_rank * self.experts_per_rank + and i < (self.ep_rank + 1) * self.experts_per_rank + else None + ) + for i in range(config.n_routed_experts) + ] + ) + else: + self.ep_size = 1 + self.experts_per_rank = config.n_routed_experts + self.ep_rank = 0 + self.experts = nn.ModuleList( + [ + DeepseekV2MLP( + config, intermediate_size=config.moe_intermediate_size + ) + for i in range(config.n_routed_experts) + ] + ) + self.gate = MoEGate(config) + if config.n_shared_experts is not None: + intermediate_size = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = DeepseekV2MLP( + config=config, intermediate_size=intermediate_size + ) + + def forward(self, hidden_states): + identity = hidden_states + orig_shape = hidden_states.shape + topk_idx, topk_weight, aux_loss = self.gate(hidden_states) + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + flat_topk_idx = topk_idx.view(-1) + if self.training: + hidden_states = hidden_states.repeat_interleave( + self.num_experts_per_tok, dim=0 + ) + y = torch.empty_like(hidden_states) + for i, expert in enumerate(self.experts): + y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) + y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y = y.to(hidden_states.dtype).view(*orig_shape) + y = AddAuxiliaryLoss.apply(y, aux_loss) + else: + y = self.moe_infer(hidden_states, 
topk_idx, topk_weight).view(*orig_shape) + if self.config.n_shared_experts is not None: + y = y + self.shared_experts(identity) + return y + + @torch.no_grad() + def moe_infer(self, x, topk_ids, topk_weight): + cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts))) + cnts.scatter_(1, topk_ids, 1) + tokens_per_expert = cnts.sum(dim=0) + idxs = topk_ids.view(-1).argsort() + sorted_tokens = x[idxs // topk_ids.shape[1]] + sorted_tokens_shape = sorted_tokens.shape + if self.ep_size > 1: + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_expert_group = tokens_per_expert.new_empty( + tokens_per_expert.shape[0] + ) + dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert) + output_splits = ( + tokens_per_expert_group.view(self.ep_size, -1) + .sum(1) + .cpu() + .numpy() + .tolist() + ) + gathered_tokens = sorted_tokens.new_empty( + tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + ) + input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() + dist.all_to_all( + list(gathered_tokens.split(output_splits)), + list(sorted_tokens.split(input_split_sizes)), + ) + tokens_per_expert_post_gather = tokens_per_expert_group.view( + self.ep_size, self.experts_per_rank + ).sum(dim=0) + gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + s = 0 + for i, k in enumerate(tokens_per_expert_group.cpu().numpy()): + gatherd_idxs[s : s + k] = i % self.experts_per_rank + s += k + gatherd_idxs = gatherd_idxs.argsort() + sorted_tokens = gathered_tokens[gatherd_idxs] + tokens_per_expert = tokens_per_expert_post_gather + tokens_per_expert = tokens_per_expert.cpu().numpy() + + outputs = [] + start_idx = 0 + for i, num_tokens in enumerate(tokens_per_expert): + end_idx = start_idx + num_tokens + if num_tokens == 0: + continue + expert = self.experts[i + self.ep_rank * self.experts_per_rank] + tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + expert_out = expert(tokens_for_this_expert) + outputs.append(expert_out) + start_idx = end_idx + + outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + if self.ep_size > 1: + new_x = torch.empty_like(outs) + new_x[gatherd_idxs] = outs + gathered_tokens = new_x.new_empty(*sorted_tokens_shape) + dist.all_to_all( + list(gathered_tokens.split(input_split_sizes)), + list(new_x.split(output_splits)), + ) + outs = gathered_tokens + + new_x = torch.empty_like(outs) + new_x[idxs] = outs + final_out = ( + new_x.view(*topk_ids.shape, -1) + .type(topk_weight.dtype) + .mul_(topk_weight.unsqueeze(dim=-1)) + .sum(dim=1) + .type(new_x.dtype) + ) + return final_out + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand( + batch, num_key_value_heads, n_rep, slen, head_dim + ) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 +class DeepseekV2Attention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will " + "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.q_lora_rank = config.q_lora_rank + self.qk_rope_head_dim = config.qk_rope_head_dim + self.kv_lora_rank = config.kv_lora_rank + self.v_head_dim = config.v_head_dim + self.qk_nope_head_dim = config.qk_nope_head_dim + self.q_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim + + self.is_causal = True + + if self.q_lora_rank is None: + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.q_head_dim, bias=False + ) + else: + self.q_a_proj = nn.Linear( + self.hidden_size, config.q_lora_rank, bias=config.attention_bias + ) + self.q_a_layernorm = DeepseekV2RMSNorm(config.q_lora_rank) + self.q_b_proj = nn.Linear( + config.q_lora_rank, self.num_heads * self.q_head_dim, bias=False + ) + + self.kv_a_proj_with_mqa = nn.Linear( + self.hidden_size, + config.kv_lora_rank + config.qk_rope_head_dim, + bias=config.attention_bias, + ) + self.kv_a_layernorm = DeepseekV2RMSNorm(config.kv_lora_rank) + self.kv_b_proj = nn.Linear( + config.kv_lora_rank, + self.num_heads + * (self.q_head_dim - self.qk_rope_head_dim + self.v_head_dim), + bias=False, + ) + + self.o_proj = nn.Linear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=config.attention_bias, + ) + self._init_rope() + + self.softmax_scale = self.q_head_dim ** (-0.5) + if self.config.rope_scaling is not None: + mscale_all_dim = self.config.rope_scaling.get("mscale_all_dim", 0) + scaling_factor = self.config.rope_scaling["factor"] + if mscale_all_dim: + mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = DeepseekV2RotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = DeepseekV2LinearScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = 
DeepseekV2DynamicNTKScalingRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "yarn": + kwargs = { + key: self.config.rope_scaling[key] + for key in [ + "original_max_position_embeddings", + "beta_fast", + "beta_slow", + "mscale", + "mscale_all_dim", + ] + if key in self.config.rope_scaling + } + self.rotary_emb = DeepseekV2YarnRotaryEmbedding( + self.qk_rope_head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + **kwargs, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return ( + tensor.view(bsz, seq_len, self.num_heads, self.v_head_dim) + .transpose(1, 2) + .contiguous() + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + attn_weights = ( + torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + ) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + assert attention_mask is not None + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training + ) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.v_head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.v_head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->DeepseekV2 +class DeepseekV2FlashAttention2(DeepseekV2Attention): + """ + DeepseekV2 flash attention module. This module inherits from `DeepseekV2Attention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
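+ # For illustration: with q_len=2 queries attending over kv_len=4 cached keys,
+ # bottom-right alignment (flash_attn>=2.1) lets query 0 attend to keys 0..2 and
+ # query 1 to keys 0..3, whereas a top-left aligned mask (flash_attn<2.1) would
+ # restrict query 0 to key 0 and query 1 to keys 0..1, masking valid cache entries.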
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # DeepseekV2FlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + if self.q_lora_rank is None: + q = self.q_proj(hidden_states) + else: + q = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))) + q = q.view(bsz, q_len, self.num_heads, self.q_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + compressed_kv, k_pe = torch.split( + compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1 + ) + k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) + kv = ( + self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) + .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .transpose(1, 2) + ) + + k_nope, value_states = torch.split( + kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1 + ) + kv_seq_len = value_states.shape[-2] + + kv_seq_len = value_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) + + query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states[:, :, :, : self.qk_nope_head_dim] = q_nope + query_states[:, :, :, self.qk_nope_head_dim :] = q_pe + + key_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + key_states[:, :, :, : self.qk_nope_head_dim] = k_nope + key_states[:, :, :, self.qk_nope_head_dim :] = k_pe + + if self.q_head_dim != self.v_head_dim: + value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. 
Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (DeepseekV2RMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + # Handle the case where the model is quantized + if hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + elif torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + else: + target_dtype = ( + self.q_proj.weight.dtype + if self.q_lora_rank is None + else self.q_a_proj.weight.dtype + ) + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + softmax_scale=self.softmax_scale, + ) + if self.q_head_dim != self.v_head_dim: + attn_output = attn_output[:, :, :, : self.v_head_dim] + + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in DeepseekV2FlashAttention2 __init__. 
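+ # With a top-left aligned kernel, requesting a causal mask for a single decode
+ # step (query_length == 1) would hide every cached key except the first one;
+ # since one query attending over the whole cache needs no mask anyway, causal is
+ # only enabled here when more than one query token is present.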
+ causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + ( + query_states, + key_states, + value_states, + indices_q, + cu_seq_lens, + max_seq_lens, + ) = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input( + attn_output_unpad, indices_q, batch_size, query_length + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + return attn_output + + def _upad_input( + self, query_layer, key_layer, value_layer, attention_mask, query_length + ): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + indices_k, + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + indices_k, + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. 
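+ # For illustration: with attention_mask=[[0, 0, 1, 1, 1]] (two left-pad tokens)
+ # and query_length=2, the slice below keeps the last two columns [[1, 1]], so
+ # the query rows passed to unpad_input line up with the final, non-padded
+ # key/value positions.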
+ attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( + query_layer, attention_mask + ) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +ATTENTION_CLASSES = { + "eager": DeepseekV2Attention, + "flash_attention_2": DeepseekV2FlashAttention2, +} + + +class DeepseekV2DecoderLayer(nn.Module): + def __init__(self, config: DeepseekV2Config, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = ATTENTION_CLASSES[config._attn_implementation]( + config=config, layer_idx=layer_idx + ) + + self.mlp = ( + DeepseekV2MoE(config) + if ( + config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0 + ) + else DeepseekV2MLP(config) + ) + self.input_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.post_attention_layernorm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +DeepseekV2_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) 
+ + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`DeepseekV2Config`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2PreTrainedModel(PreTrainedModel): + config_class = DeepseekV2Config + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["DeepseekV2DecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +DeepseekV2_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. 
+ + Two formats are allowed: + - a [`~cache_utils.Cache`] instance; + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare DeepseekV2 Model outputting raw hidden-states without any specific head on top.", + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2Model(DeepseekV2PreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. 
Each layer is a [`DeepseekV2DecoderLayer`] + + Args: + config: DeepseekV2Config + """ + + def __init__(self, config: DeepseekV2Config): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + self.layers = nn.ModuleList( + [ + DeepseekV2DecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time" + ) + elif input_ids is not None: + batch_size, seq_length = input_ids.shape[:2] + elif inputs_embeds is not None: + batch_size, seq_length = inputs_embeds.shape[:2] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers." 
+ ) + use_cache = False + + past_key_values_length = 0 + if use_cache: + use_legacy_cache = not isinstance(past_key_values, Cache) + if use_legacy_cache: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + past_key_values_length = past_key_values.get_usable_length(seq_length) + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if self._use_flash_attention_2: + # 2d mask is passed through the layers + attention_mask = ( + attention_mask + if (attention_mask is not None and 0 in attention_mask) + else None + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, + (batch_size, seq_length), + inputs_embeds, + past_key_values_length, + ) + + # embed positions + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + attention_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = None + if use_cache: + next_cache = ( + next_decoder_cache.to_legacy_cache() + if use_legacy_cache + else next_decoder_cache + ) + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + if v is not None + ) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = DeepseekV2Model(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + @replace_return_docstrings( + output_type=CausalLMOutputWithPast, 
config_class=_CONFIG_FOR_DOC + ) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, transformers., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, transformers., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, DeepseekV2ForCausalLM + + >>> model = DeepseekV2ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + **kwargs, + ): + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = 
past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # input) + if ( + attention_mask is not None + and attention_mask.shape[1] > input_ids.shape[1] + ): + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple( + past_state.index_select(0, beam_idx.to(past_state.device)) + for past_state in layer_past + ), + ) + return reordered_past + + +@add_start_docstrings( + """ + The DeepseekV2 Model transformer with a sequence classification head on top (linear layer). + + [`DeepseekV2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
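+
+ For illustration, with `pad_token_id=0` and `input_ids=[[11, 12, 0, 0]]`, position 1 holds the last
+ non-padding token, so the classification logits are pooled from that position.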
+ """, + DeepseekV2_START_DOCSTRING, +) +class DeepseekV2ForSequenceClassification(DeepseekV2PreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = DeepseekV2Model(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, transformers., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError( + "Cannot handle batch sizes > 1 if no padding token is defined." 
+ ) + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + sequence_lengths = ( + torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + ).to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[ + torch.arange(batch_size, device=logits.device), sequence_lengths + ] + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and ( + labels.dtype == torch.long or labels.dtype == torch.int + ): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct( + pooled_logits.view(-1, self.num_labels), labels.view(-1) + ) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/moe_infinity/models/modeling_deepseek/tokenization_deepseek_fast.py b/moe_infinity/models/modeling_deepseek/tokenization_deepseek_fast.py new file mode 100644 index 0000000..d243771 --- /dev/null +++ b/moe_infinity/models/modeling_deepseek/tokenization_deepseek_fast.py @@ -0,0 +1,38 @@ +from typing import List, Optional, Union + + +from transformers.models.llama import LlamaTokenizerFast + + +class DeepseekTokenizerFast(LlamaTokenizerFast): + + def convert_ids_to_tokens( + self, ids: Union[int, List[int]], skip_special_tokens: bool = False + ) -> Union[str, List[str]]: + """ + Converts a single index or a sequence of indices in a token or a sequence of tokens, using the vocabulary and + added tokens. + + Args: + ids (`int` or `List[int]`): + The token id (or token ids) to convert to tokens. + skip_special_tokens (`bool`, *optional*, defaults to `False`): + Whether or not to remove special tokens in the decoding. + + Returns: + `str` or `List[str]`: The decoded token(s). 
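+
+ Example (illustrative; assumes an already-instantiated `DeepseekTokenizerFast` named `tokenizer`):
+
+ ```python
+ tokens = tokenizer.convert_ids_to_tokens([6, 7, 8])  # list of three token strings
+ ```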
+ """ + if isinstance(ids, int): + return self._convert_id_to_token(ids) + tokens = [] + for index in ids: + index = int(index) + if skip_special_tokens and index in self.all_special_ids: + continue + token = self._tokenizer.id_to_token(index) + tokens.append(token if token is not None else "") + return tokens + + def _convert_id_to_token(self, index: int) -> Optional[str]: + token = self._tokenizer.id_to_token(int(index)) + return token if token is not None else "" diff --git a/moe_infinity/models/modeling_grok/configuration_grok1.py b/moe_infinity/models/modeling_grok/configuration_grok1.py index c10f895..a30f97b 100644 --- a/moe_infinity/models/modeling_grok/configuration_grok1.py +++ b/moe_infinity/models/modeling_grok/configuration_grok1.py @@ -28,7 +28,7 @@ def __init__( num_experts=8, output_router_logits=False, router_aux_loss_coef=0.001, - **kwargs + **kwargs, ): self.vocab_size = vocab_size self.attn_output_multiplier = attn_output_multiplier @@ -59,4 +59,4 @@ def __init__( eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, - ) \ No newline at end of file + ) diff --git a/moe_infinity/models/modeling_grok/modeling_grok1.py b/moe_infinity/models/modeling_grok/modeling_grok1.py index bf2ee35..5ae2a88 100644 --- a/moe_infinity/models/modeling_grok/modeling_grok1.py +++ b/moe_infinity/models/modeling_grok/modeling_grok1.py @@ -7,16 +7,19 @@ from transformers.utils import logging try: - from transformers.modeling_attn_mask_utils import \ - _prepare_4d_causal_attention_mask + from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + ) HAS_MASK_UTILS = True except ImportError: HAS_MASK_UTILS = False from .configuration_grok1 import Grok1Config -from .modeling_grok1_outputs import (MoeCausalLMOutputWithPast, - MoeModelOutputWithPast) +from .modeling_grok1_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) logger = logging.get_logger(__name__) @@ -69,7 +72,8 @@ def load_balancing_loss_func( router_prob_per_group_and_expert = torch.mean(routing_weights, axis=-1) return torch.mean( - tokens_per_group_and_expert * router_prob_per_group_and_expert.unsqueeze(-1) + tokens_per_group_and_expert + * router_prob_per_group_and_expert.unsqueeze(-1) ) * (num_experts**2) @@ -85,7 +89,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + return hidden_states.reshape( + batch, num_key_value_heads * n_rep, slen, head_dim + ) class RMSNorm(nn.Module): @@ -106,7 +112,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) hidden_states = self.scale * hidden_states return hidden_states.to(input_dtype) @@ -140,13 +148,19 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + 
self.register_buffer( + "cos_cached", emb.cos().to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin().to(dtype), persistent=False + ) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache( + seq_len=seq_len, device=x.device, dtype=x.dtype + ) return ( self.cos_cached[:seq_len].to(dtype=x.dtype), @@ -217,14 +231,18 @@ def __init__( f" and `num_heads`: {self.num_heads})." ) - self.q_proj = nn.Linear(hidden_size, self.num_heads * self.head_dim, bias=False) + self.q_proj = nn.Linear( + hidden_size, self.num_heads * self.head_dim, bias=False + ) self.k_proj = nn.Linear( hidden_size, self.num_key_value_heads * self.head_dim, bias=False ) self.v_proj = nn.Linear( hidden_size, self.num_key_value_heads * self.head_dim, bias=False ) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, hidden_size, bias=False) + self.o_proj = nn.Linear( + self.num_heads * self.head_dim, hidden_size, bias=False + ) self.rotary_emb = RotaryEmbedding( self.head_dim, @@ -240,7 +258,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]] + ]: bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -277,11 +297,13 @@ def forward( key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) - attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)).to( - torch.float - ) + attn_weights = torch.matmul( + query_states, key_states.transpose(2, 3) + ).to(torch.float) attn_weights = attn_weights * self.attn_output_multiplier - attn_weights = self.max_attn_val * F.tanh(attn_weights / self.max_attn_val) + attn_weights = self.max_attn_val * F.tanh( + attn_weights / self.max_attn_val + ) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): raise ValueError( @@ -330,9 +352,9 @@ def __init__( self.act_fn = nn.GELU() def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - current_hidden_states = self.act_fn(self.linear(hidden_states)) * self.linear_v( - hidden_states - ) + current_hidden_states = self.act_fn( + self.linear(hidden_states) + ) * self.linear_v(hidden_states) current_hidden_states = self.linear_1(current_hidden_states) return current_hidden_states @@ -392,7 +414,9 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: # Index the correct hidden states and compute the expert hidden state for # the current expert. 
We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_state = hidden_states[None, top_x_list].reshape( + -1, hidden_dim + ) current_hidden_states = ( expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] @@ -432,7 +456,9 @@ def __init__( attn_output_multiplier=attn_output_multiplier, max_attn_val=max_attn_val, ) - self.moe_block = MoeBlock(hidden_size, intermediate_size, num_experts, top_k) + self.moe_block = MoeBlock( + hidden_size, intermediate_size, num_experts, top_k + ) self.pre_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps) self.post_attn_norm = RMSNorm(hidden_size, eps=rms_norm_eps) self.pre_moe_norm = RMSNorm(hidden_size, eps=rms_norm_eps) @@ -530,14 +556,18 @@ def _make_causal_mask( # Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): +def _expand_mask( + mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None +): """ Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. """ bsz, src_len = mask.size() tgt_len = tgt_len if tgt_len is not None else src_len - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + expanded_mask = ( + mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + ) inverted_mask = 1.0 - expanded_mask @@ -634,10 +664,14 @@ def forward( if output_hidden_states is not None else self.config.output_hidden_states ) - use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache + ) return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + return_dict + if return_dict is not None + else self.config.use_return_dict ) # retrieve input_ids and inputs_embeds @@ -650,7 +684,9 @@ def forward( elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + raise ValueError( + "You have to specify either input_ids or inputs_embeds" + ) seq_length_with_past = seq_length past_key_values_length = 0 @@ -659,7 +695,11 @@ def forward( seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device + device = ( + input_ids.device + if input_ids is not None + else inputs_embeds.device + ) position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, @@ -723,7 +763,9 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions) + return module( + *inputs, past_key_value, output_attentions + ) return custom_forward @@ -746,7 +788,9 @@ def custom_forward(*inputs): hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache += ( + layer_outputs[2 if output_attentions else 1], + ) if output_attentions: all_self_attns += (layer_outputs[1],) @@ -790,7 +834,9 @@ def __init__(self, config: Grok1Config, **kwargs): self.model = Grok1Model(config) self.vocab_size = config.vocab_size self.output_multiplier_scale = config.output_multiplier_scale - self.lm_head = nn.Linear(config.hidden_size, 
config.vocab_size, bias=False) + self.lm_head = nn.Linear( + config.hidden_size, config.vocab_size, bias=False + ) self.router_aux_loss_coef = config.router_aux_loss_coef self.num_experts = config.num_experts self.num_experts_per_tok = config.num_experts_per_tok @@ -845,7 +891,9 @@ def forward( else self.config.output_hidden_states ) return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + return_dict + if return_dict is not None + else self.config.use_return_dict ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) diff --git a/moe_infinity/models/modeling_grok/modeling_grok1_outputs.py b/moe_infinity/models/modeling_grok/modeling_grok1_outputs.py index 48cb7eb..21a495a 100644 --- a/moe_infinity/models/modeling_grok/modeling_grok1_outputs.py +++ b/moe_infinity/models/modeling_grok/modeling_grok1_outputs.py @@ -103,4 +103,4 @@ class MoeCausalLMOutputWithPast(ModelOutput): past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None - router_logits: Optional[Tuple[torch.FloatTensor]] = None \ No newline at end of file + router_logits: Optional[Tuple[torch.FloatTensor]] = None diff --git a/moe_infinity/models/nllb_moe.py b/moe_infinity/models/nllb_moe.py index 894fdb1..1d5c1eb 100644 --- a/moe_infinity/models/nllb_moe.py +++ b/moe_infinity/models/nllb_moe.py @@ -4,12 +4,13 @@ # TorchMoE Team from typing import Dict, Optional + import torch import torch.nn as nn from transformers import NllbMoeConfig from transformers.models.nllb_moe.modeling_nllb_moe import ( - NllbMoeTop2Router, NllbMoeDenseActDense, + NllbMoeTop2Router, ) from moe_infinity.utils import ArcherConfig @@ -18,7 +19,6 @@ class SyncNllbMoeSparseMLP(nn.Module): - archer_config: ArcherConfig = None layer_id: int = None @@ -43,36 +43,51 @@ def __init__( self.archer_engine = None self.expert_tensor_ids: Dict[int, int] = None - def forward(self, - hidden_states: torch.Tensor, - padding_mask: Optional[torch.Tensor] = None): + def forward( + self, + hidden_states: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ): batch_size, sequence_length, hidden_dim = hidden_states.shape top_1_mask, router_probs = self.router(hidden_states, padding_mask) combining_weights = router_probs.reshape( - (batch_size, sequence_length, self.num_experts)) + (batch_size, sequence_length, self.num_experts) + ) router_mask = combining_weights.bool() next_states = torch.zeros_like(hidden_states) top_1_expert_index = torch.argmax(top_1_mask, dim=-1) - logits_except_top_1 = router_probs.masked_fill(top_1_mask.bool(), float("-inf")) + logits_except_top_1 = router_probs.masked_fill( + top_1_mask.bool(), float("-inf") + ) top_2_expert_index = torch.argmax(logits_except_top_1, dim=-1) # top_2_mask = torch.nn.functional.one_hot(top_2_expert_index, num_classes=self.num_experts) - expert_index = torch.stack([top_1_expert_index, top_2_expert_index], dim=-1) + expert_index = torch.stack( + [top_1_expert_index, top_2_expert_index], dim=-1 + ) expert_index = expert_index.reshape(batch_size, sequence_length, 2) for i in range(batch_size): seq_id = self.seq_id_list[i] - expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id) - self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix) - - results = self.expert_executor.dispatch_local(hidden_states, router_mask, self.layer_id) + expert_matrix = self.expert_predictor.predict( + seq_id, 
expert_index[i], self.layer_id + ) + self.expert_prefetcher.prefetch_experts( + self.layer_id, expert_matrix + ) + + results = self.expert_executor.dispatch_local( + hidden_states, router_mask, self.layer_id + ) for output, _, idx, _ in results: token_indices = router_mask[:, idx].bool() weights = combining_weights[..., idx] - next_states[token_indices] += torch.einsum("b,be->be", weights[token_indices], output.to(weights.device)) + next_states[token_indices] += torch.einsum( + "b,be->be", weights[token_indices], output.to(weights.device) + ) # for expert_id, expert in self.experts.items(): # idx = int(expert_id.split("_")[-1]) @@ -86,6 +101,7 @@ def forward(self, next_states[next_states == 0] = hidden_states[next_states == 0] hidden_states = next_states - return hidden_states, (router_probs.to("cuda:0", non_blocking=True), - top_1_expert_index.to("cuda:0", - non_blocking=True)) + return hidden_states, ( + router_probs.to("cuda:0", non_blocking=True), + top_1_expert_index.to("cuda:0", non_blocking=True), + ) diff --git a/moe_infinity/models/switch_transformers.py b/moe_infinity/models/switch_transformers.py index 6a1a061..d974a6f 100644 --- a/moe_infinity/models/switch_transformers.py +++ b/moe_infinity/models/switch_transformers.py @@ -4,14 +4,16 @@ # TorchMoE Team from typing import Dict + import torch import torch.nn as nn from transformers import SwitchTransformersConfig +from transformers.activations import ACT2FN from transformers.models.switch_transformers.modeling_switch_transformers import ( - SwitchTransformersTop1Router, SwitchTransformersDenseActDense, + SwitchTransformersTop1Router, ) -from transformers.activations import ACT2FN + from ..memory import ExpertPredictor from ..utils import ArcherConfig @@ -19,7 +21,6 @@ class SwitchTransformersDenseGatedActDense(nn.Module): - def __init__(self, config: SwitchTransformersConfig): super().__init__() self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False) @@ -71,7 +72,7 @@ def __init__( self.expert_predictor: ExpertPredictor = None def forward(self, hidden_states): - # Step 1: Get the router_mask from the router as wel as the probabilities + # Step 1: Get the router_mask from the router as well as the probabilities router_mask, router_probs, router_logits = self.router(hidden_states) expert_index = torch.argmax(router_mask, dim=-1) @@ -84,16 +85,21 @@ def forward(self, hidden_states): expert_index = expert_index.reshape(batch_size, -1) for i in range(batch_size): seq_id = self.seq_id_list[i] - expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id) - self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix) + expert_matrix = self.expert_predictor.predict( + seq_id, expert_index[i], self.layer_id + ) + self.expert_prefetcher.prefetch_experts( + self.layer_id, expert_matrix + ) - - results = self.expert_executor.dispatch_local(hidden_states, router_mask, self.layer_id) + results = self.expert_executor.dispatch_local( + hidden_states, router_mask, self.layer_id + ) for output, _, idx, _ in results: token_indices = router_mask[:, :, idx].bool() next_states[token_indices] = output.to(next_states.device) - + # for expert_id, expert in self.experts.items(): # idx = int(expert_id.split("_")[-1]) # token_indices = router_mask[:, :, idx].bool() @@ -101,5 +107,7 @@ def forward(self, hidden_states): # next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.device) hidden_states = router_probs * next_states - return hidden_states, (router_logits.to("cuda:0", 
non_blocking=True), - expert_index.to("cuda:0", non_blocking=True)) + return hidden_states, ( + router_logits.to("cuda:0", non_blocking=True), + expert_index.to("cuda:0", non_blocking=True), + ) diff --git a/moe_infinity/runtime/__init__.py b/moe_infinity/runtime/__init__.py index 7ce8b98..fdceb1f 100644 --- a/moe_infinity/runtime/__init__.py +++ b/moe_infinity/runtime/__init__.py @@ -1 +1 @@ -from .model_offload import OffloadEngine \ No newline at end of file +from .model_offload import OffloadEngine diff --git a/moe_infinity/runtime/model_offload.py b/moe_infinity/runtime/model_offload.py index 1802b53..241b9be 100644 --- a/moe_infinity/runtime/model_offload.py +++ b/moe_infinity/runtime/model_offload.py @@ -3,60 +3,55 @@ # TorchMoE Team +import functools import gc +import json import os -import numpy as np -import math -import torch.distributed as dist -from torch.distributed import rpc -from auto_gptq.nn_modules.qlinear.qlinear_cuda import QuantLinear -from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import QuantLinear as QuantLinearOld +import re +from typing import Callable, Dict, Type, Union import torch -import functools -import json +import transformers +# import torch.distributed as dist +# from torch.distributed import rpc +from auto_gptq.nn_modules.qlinear.qlinear_cuda import QuantLinear +from auto_gptq.nn_modules.qlinear.qlinear_cuda_old import ( + QuantLinear as QuantLinearOld, +) +from safetensors import safe_open from tqdm import tqdm +from transformers.modeling_utils import PretrainedConfig, PreTrainedModel -from moe_infinity.ops.op_builder.prefetch import PrefetchBuilder +import moe_infinity +from moe_infinity.common import parse_expert_type +from moe_infinity.distributed import DistributedExpertExecutor +from moe_infinity.memory import ExpertPredictor, ExpertPrefetcher, ExpertTracer from moe_infinity.models import ( - SyncSwitchTransformersSparseMLP, - SyncNllbMoeSparseMLP, - SyncMixtralSparseMoeBlock, - SyncGrokMoeBlock, + DeepseekV2MoEBlock, SyncArcticMoeBlock, + SyncGrokMoeBlock, + SyncMixtralSparseMoeBlock, + SyncNllbMoeSparseMLP, + SyncSwitchTransformersSparseMLP, ) -from moe_infinity.utils import ArcherConfig -from moe_infinity.utils.arguments import copy_args_to_device, copy_kwargs_to_device -from moe_infinity.distributed import DistributedExpertExecutor - -from moe_infinity.memory import ExpertPrefetcher -import moe_infinity +from moe_infinity.ops.op_builder.prefetch import PrefetchBuilder from moe_infinity.utils import ( - parse_moe_param, - parse_expert_id, + ArcherConfig, parse_expert_dtype, + parse_expert_id, + parse_moe_param, ) -from moe_infinity.common import parse_expert_type -from moe_infinity.memory import ExpertTracer, ExpertPredictor, ExpertCache - -from typing import Dict, Type, Union -from transformers import ( - AutoConfig, +from moe_infinity.utils.arguments import ( + copy_args_to_device, + copy_kwargs_to_device, ) -from transformers.modeling_utils import PreTrainedModel, PretrainedConfig -import transformers -from typing import Callable - -from safetensors import safe_open - -import re use_jit = False try: import moe_infinity.ops.prefetch.prefetch_op as prefetch_op except ImportError: - print(f"Do not detect pre-installed ops, use JIT mode") + print("Do not detect pre-installed ops, use JIT mode") use_jit = True @@ -71,7 +66,6 @@ class OffloadEngine(object): config = {} def __init__(self, capacity, config: PretrainedConfig): - self.offload_exemption = set() self.expert_modules = [] @@ -91,9 +85,10 @@ def __init__(self, capacity, config: 
PretrainedConfig): # def init_trace(self, trace_path: str): def init( - self, cls: Type[PreTrainedModel], ar_config: Union[str, Dict, ArcherConfig] + self, + cls: Type[PreTrainedModel], + ar_config: Union[str, Dict, ArcherConfig], ): - self.cls = cls self.name_id_map = {} self.tensor_id_map = {} @@ -157,7 +152,9 @@ def init( # ): # os.remove(_archer_config.perfect_cache_file) - self.expert_executor = DistributedExpertExecutor(archer_config=_archer_config) + self.expert_executor = DistributedExpertExecutor( + archer_config=_archer_config + ) # self.expert_prefetcher = ExpertPrefetcher(self.config) # self.device_map_manager = DeviceMapManager(archer_config=_archer_config) @@ -168,9 +165,7 @@ def init( return self def __enter__(self): - def do_nothing_decorator(orig_func: Callable) -> Callable: - @functools.wraps(orig_func) def do_nothing(*args, **kwargs): pass @@ -186,31 +181,35 @@ def archer_post_init(cls, *args, **kwargs): return archer_post_init def torch_index_select_decorator(orig_torch_index_select: Callable): - @functools.wraps(orig_torch_index_select) def archer_torch_index_select(input, dim, index): - return orig_torch_index_select(input, dim, index.to(input.device)).to( - "cuda:0" - ) + return orig_torch_index_select( + input, dim, index.to(input.device) + ).to("cuda:0") return archer_torch_index_select def apply_to_model_decorator(orig_apply_to_model: Callable) -> Callable: - @functools.wraps(orig_apply_to_model) def archer_apply_to_model(cls, fn): for name, param in cls.named_parameters(recurse=True): if name not in self.name_id_map: continue param.data = torch.zeros( - 1, dtype=param.dtype, device=param.device, pin_memory=True + 1, + dtype=param.dtype, + device=param.device, + pin_memory=True, ) for name, buffer in cls.named_buffers(recurse=True): if name not in self.name_id_map: continue buffer.data = torch.zeros( - 1, dtype=buffer.dtype, device=buffer.device, pin_memory=True + 1, + dtype=buffer.dtype, + device=buffer.device, + pin_memory=True, ) return archer_apply_to_model @@ -230,7 +229,6 @@ def archer_apply_to_model(cls, fn): # return archer_load_pretrained_model def init_decorator(orig_init: Callable) -> Callable: - @functools.wraps(orig_init) def archer_init(cls, config, *args, **kwargs): # self.config = config @@ -249,7 +247,6 @@ def archer_init(cls, config, *args, **kwargs): # return archer_config def param_init_decorator(orig_param_init: Callable) -> Callable: - @functools.wraps(orig_param_init) def archer_param_init(cls, *args, **kwargs): orig_param_init(cls, *args, **kwargs) @@ -257,34 +254,42 @@ def archer_param_init(cls, *args, **kwargs): cls.param_real_shape = {} for name, param in cls.named_parameters(recurse=False): cls.param_real_shape[name] = param.shape - param.data = torch.zeros(1, dtype=param.dtype, device=param.device) + param.data = torch.zeros( + 1, dtype=param.dtype, device=param.device + ) self.model_create_counter.update(1) for name, buf in cls.named_buffers(recurse=False): cls.param_real_shape[name] = buf.shape - buf.data = torch.zeros(1, dtype=buf.dtype, device=buf.device) + buf.data = torch.zeros( + 1, dtype=buf.dtype, device=buf.device + ) self.model_create_counter.update(1) return archer_param_init - - def cast_classifier_decorator(orig_cast_classifier: Callable) -> Callable: + def cast_classifier_decorator( + orig_cast_classifier: Callable, + ) -> Callable: @functools.wraps(orig_cast_classifier) def archer_cast_classifier(cls, *args, **kwargs): orig_data_ptr = cls.classifier.weight.data.data_ptr() if orig_data_ptr in self.offload_set: - 
self.offload_set.remove(cls.classifier.weight.data.data_ptr()) + self.offload_set.remove( + cls.classifier.weight.data.data_ptr() + ) orig_cast_classifier(cls, *args, **kwargs) new_data_ptr = cls.classifier.weight.data.data_ptr() self.offload_set.add(cls.classifier.weight.data.data_ptr()) - self.archer_engine.update_tensor_map(orig_data_ptr, new_data_ptr) + self.archer_engine.update_tensor_map( + orig_data_ptr, new_data_ptr + ) else: orig_cast_classifier(cls, *args, **kwargs) self.offload_set.add(cls.classifier.weight.data.data_ptr()) return archer_cast_classifier - - + # GPTQ Override QuantLinear._old_init = QuantLinear.__init__ QuantLinear.__init__ = param_init_decorator(QuantLinear.__init__) @@ -303,13 +308,17 @@ def archer_cast_classifier(cls, *args, **kwargs): # transformers.modeling_utils.old_load_state_dict = ( # transformers.modeling_utils.load_state_dict) # transformers.modeling_utils.load_state_dict = load_state_dict - torch.nn.modules.module.Module._old_apply = torch.nn.modules.module.Module.apply + torch.nn.modules.module.Module._old_apply = ( + torch.nn.modules.module.Module.apply + ) torch.nn.modules.module.Module.apply = apply_to_model_decorator( torch.nn.modules.module.Module._old_apply ) torch._old_index_select = torch.index_select - torch.index_select = torch_index_select_decorator(torch._old_index_select) + torch.index_select = torch_index_select_decorator( + torch._old_index_select + ) torch.Tensor._old_index_select = torch.Tensor.index_select torch.Tensor.index_select = torch_index_select_decorator( torch.Tensor._old_index_select @@ -318,7 +327,9 @@ def archer_cast_classifier(cls, *args, **kwargs): self.cls._old_post_init = self.cls.post_init self.cls.post_init = post_init_decorator(self.cls._old_post_init) PreTrainedModel._old_post_init = PreTrainedModel.post_init - PreTrainedModel.post_init = post_init_decorator(PreTrainedModel._old_post_init) + PreTrainedModel.post_init = post_init_decorator( + PreTrainedModel._old_post_init + ) # for all the modules in torch.nn, add post_init method # assert False, torch.nn.modules.__dict__ @@ -341,17 +352,17 @@ def archer_cast_classifier(cls, *args, **kwargs): if hasattr(module, "reset_parameters"): module._old_reset_parameters = module.reset_parameters - module.reset_parameters = do_nothing_decorator(module.reset_parameters) - - transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._old_cast_classifier = transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier - transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier = cast_classifier_decorator(transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier) + module.reset_parameters = do_nothing_decorator( + module.reset_parameters + ) - transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp = ( - transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP - ) - transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP = ( - SyncSwitchTransformersSparseMLP + transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._old_cast_classifier = transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier + 
transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier = cast_classifier_decorator( + transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier ) + + transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp = transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP + transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP = SyncSwitchTransformersSparseMLP transformers.models.nllb_moe.modeling_nllb_moe._old_sparse_mlp = ( transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeSparseMLP ) @@ -364,20 +375,36 @@ def archer_cast_classifier(cls, *args, **kwargs): transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock = ( SyncMixtralSparseMoeBlock ) - moe_infinity.models.modeling_grok.modeling_grok1._old_sparse_mlp = moe_infinity.models.modeling_grok.MoeBlock - moe_infinity.models.modeling_grok.modeling_grok1.MoeBlock = SyncGrokMoeBlock - - moe_infinity.models.modeling_arctic._old_sparse_mlp = moe_infinity.models.modeling_arctic.ArcticMoE - moe_infinity.models.modeling_arctic.modeling_arctic.ArcticMoE = SyncArcticMoeBlock - - def from_pretrained_decorator(orig_from_pretrained: Callable) -> Callable: + moe_infinity.models.modeling_grok.modeling_grok1._old_sparse_mlp = ( + moe_infinity.models.modeling_grok.MoeBlock + ) + moe_infinity.models.modeling_grok.modeling_grok1.MoeBlock = ( + SyncGrokMoeBlock + ) + + moe_infinity.models.modeling_arctic._old_sparse_mlp = ( + moe_infinity.models.modeling_arctic.ArcticMoE + ) + moe_infinity.models.modeling_arctic.modeling_arctic.ArcticMoE = ( + SyncArcticMoeBlock + ) + + moe_infinity.models.modeling_deepseek._old_sparse_mlp = ( + moe_infinity.models.modeling_deepseek.DeepseekV2MoE + ) + moe_infinity.models.modeling_deepseek.modeling_deepseek.DeepseekV2MoE = DeepseekV2MoEBlock + def from_pretrained_decorator( + orig_from_pretrained: Callable, + ) -> Callable: @functools.wraps(orig_from_pretrained) def archer_from_pretrained(cls, *args, **kwargs): # print("Creating model from scratch ...") - name_id_map_file = os.path.join(self.checkpoint, "name_id_map.json") + name_id_map_file = os.path.join( + self.checkpoint, "name_id_map.json" + ) model_name = args[0] # if "arctic" in model_name: @@ -402,11 +429,15 @@ def archer_from_pretrained(cls, *args, **kwargs): empty_state_dict = {} self.name_id_map = {} for ckpt in tqdm( - self.ckpt_files, desc="Loading checkpoint files", smoothing=0 + self.ckpt_files, + desc="Loading checkpoint files", + smoothing=0, ): state_dict = {} if "safetensors" in ckpt: - with safe_open(ckpt, framework="pt", device="cpu") as f: + with safe_open( + ckpt, framework="pt", device="cpu" + ) as f: for k in f.keys(): state_dict[k] = f.get_tensor(k) else: @@ -441,14 +472,21 @@ def archer_from_pretrained(cls, *args, **kwargs): total=max_tensor_id, desc="Model create" ) - is_flash_attn_available = kwargs.get("is_flash_attn_available", False) + is_flash_attn_available = kwargs.get( + "is_flash_attn_available", False + ) # self.archer_prefetch.n_layer, self.archer_prefetch.n_expert, n_encoder_layers = parse_moe_param(self.config) - if self.dtype_cls is torch.bfloat16 or self.dtype_cls is torch.float16: + if ( + self.dtype_cls is torch.bfloat16 + or self.dtype_cls is torch.float16 + ): model = cls._from_config( self.config, torch_dtype=self.dtype_cls, attn_implementation=( - "flash_attention_2" if is_flash_attn_available 
else "eager" + "flash_attention_2" + if is_flash_attn_available + else "eager" ), ) else: @@ -464,12 +502,15 @@ def archer_from_pretrained(cls, *args, **kwargs): # print(self.config, flush=True) if hasattr(self.config, "quantization_config"): - self.quant_method = self.config.quantization_config["quant_method"] + self.quant_method = self.config.quantization_config[ + "quant_method" + ] self.config.quantization_config["use_exllama"] = False self.config.quantization_config["disable_exllama"] = True # print("Quantizing model ...", self.quant_method, flush=True) if self.quant_method == "gptq": from optimum.gptq import GPTQQuantizer + # print("Quantizing model with GPTQ ...", self.config.quantization_config, flush=True) optimum_quantizer = GPTQQuantizer.from_dict( self.config.quantization_config @@ -489,7 +530,9 @@ def archer_from_pretrained(cls, *args, **kwargs): for name, param in model.named_parameters(recurse=True): # remove base_model_prefix from self.name_id_map if name.startswith(base_model_prefix): - name_without_prefix = name[(len(base_model_prefix) + 1) :] + name_without_prefix = name[ + (len(base_model_prefix) + 1) : + ] if name_without_prefix in self.name_id_map: self.name_id_map[name] = self.name_id_map[ name_without_prefix @@ -498,8 +541,10 @@ def archer_from_pretrained(cls, *args, **kwargs): param.ar_id = self.name_id_map.get(name, None) # the case for NLLB MoE - if not "lm_head.weight" in self.name_id_map: - print("lm_head.weight not in name_id_map, add it as embed_tokens") + if "lm_head.weight" not in self.name_id_map: + print( + "lm_head.weight not in name_id_map, add it as embed_tokens" + ) self.name_id_map["lm_head.weight"] = 0 self.name_id_map["encoder.embed_tokens.weight"] = 0 self.name_id_map["decoder.embed_tokens.weight"] = 0 @@ -513,9 +558,18 @@ def archer_from_pretrained(cls, *args, **kwargs): layer_id, expert_id = parse_expert_id(name, self.config) if expert_id is not None: self.expert_tensor_map[(layer_id, expert_id)] = id + # print("expert_tensor_map", self.expert_tensor_map, flush=True) + self.expert_prefetcher.expert_tensor_map = ( + self.expert_tensor_map + ) - self.expert_prefetcher.expert_tensor_map = self.expert_tensor_map - + # for deepseek, we need to set the expert_tensor_map for the model + first_k_dense_replace = 0 + if "deepseek" in model_name: + self.expert_prefetcher.first_k_dense_replace = ( + self.config.first_k_dense_replace + ) + first_k_dense_replace = self.config.first_k_dense_replace # extracted_experts = [] # for param_name, tensor_id in self.name_id_map.items(): # # extract encoder, digits from "encoder.layers.7.ffn.experts.expert_78.fc1.weight" @@ -548,12 +602,13 @@ def archer_from_pretrained(cls, *args, **kwargs): # # make unique and sort # layer_idx = sorted(list(set(layer_idx))) - self.expert_executor.set_expert_dispatcher(self.expert_dispatcher) + self.expert_executor.set_expert_dispatcher( + self.expert_dispatcher + ) module_idx = 0 self.expert_layer_modules = [] for module in model.modules(): - if ( isinstance(module, SyncNllbMoeSparseMLP) or isinstance(module, SyncSwitchTransformersSparseMLP) @@ -561,6 +616,7 @@ def archer_from_pretrained(cls, *args, **kwargs): or isinstance(module, SyncMixtralSparseMoeBlock) or isinstance(module, SyncGrokMoeBlock) or isinstance(module, SyncArcticMoeBlock) + or isinstance(module, DeepseekV2MoEBlock) ): # module.archer_prefetch = self.archer_prefetch # module.archer_tracer = self.archer_tracer @@ -594,7 +650,7 @@ def archer_from_pretrained(cls, *args, **kwargs): # # 
self.archer_prefetch.extracted_experts[module_idx] = [ # # x[1] for x in expert_tensor_ids # # ] - module.layer_id = module_idx + module.layer_id = module_idx + first_k_dense_replace module_idx += 1 @@ -613,17 +669,18 @@ def archer_from_pretrained(cls, *args, **kwargs): # clean up initialization hooks def __exit__(self, exc_type, exc_value, traceback): - # GPTQ Override QuantLinear.__init__ = QuantLinear._old_init QuantLinearOld.__init__ = QuantLinearOld._old_init - + self.cls.__init__ = self.cls._old_init self.cls.from_pretrained = self.cls._old_from_pretrained - torch.nn.modules.module.Module.apply = torch.nn.modules.module.Module._old_apply + torch.nn.modules.module.Module.apply = ( + torch.nn.modules.module.Module._old_apply + ) torch.index_select = torch._old_index_select torch.Tensor.index_select = torch.Tensor._old_index_select - + self.cls.post_init = self.cls._old_post_init PreTrainedModel.post_init = PreTrainedModel._old_post_init @@ -659,7 +716,7 @@ def get_topology(self, model): print("param not in self.name_id_map", name) continue if match: - if "expert" in name: + if "expert" in name and "shared_experts" not in name: match = re.match(r"(.*experts)", name) assert match, "Not correct expert name!" stored_name = match.group(1) @@ -676,7 +733,9 @@ def get_topology(self, model): self.name_id_map[name] ] else: - ret_dict[stored_name] = {expert_name: [self.name_id_map[name]]} + ret_dict[stored_name] = { + expert_name: [self.name_id_map[name]] + } name_lst.append(stored_name) else: @@ -706,7 +765,7 @@ def get_topology(self, model): # print("buffer not in self.name_id_map", name) continue if match: - if "expert" in name: + if "expert" in name and "shared_experts" not in name: match = re.match(r"(.*experts)", name) assert match, "Not correct expert name!" 
stored_name = match.group(1) @@ -723,12 +782,16 @@ def get_topology(self, model): self.name_id_map[name] ] else: - ret_dict[stored_name] = {expert_name: [self.name_id_map[name]]} + ret_dict[stored_name] = { + expert_name: [self.name_id_map[name]] + } name_lst.append(stored_name) else: matches = [match for match in re.finditer(r"\d", name)] - last_number_position = matches[-1].start() if matches else -1 + last_number_position = ( + matches[-1].start() if matches else -1 + ) stored_name = name[: last_number_position + 1] if stored_name in name_lst: @@ -790,8 +853,9 @@ def _post_forward_output_hook(module, input, output, device, tensors): new_args = output.to(device) return new_args - def gen_args_hook(key, input_device_index, output_device_index, tensors): - + def gen_args_hook( + key, input_device_index, output_device_index, tensors + ): keys = key.split(".") # print(keys) m = model @@ -803,7 +867,9 @@ def gen_args_hook(key, input_device_index, output_device_index, tensors): m.register_forward_pre_hook( functools.partial( - _pre_forward_input_hook, device=input_device_index, tensors=tensors + _pre_forward_input_hook, + device=input_device_index, + tensors=tensors, ), prepend=True, with_kwargs=True, @@ -828,11 +894,16 @@ def gen_args_hook(key, input_device_index, output_device_index, tensors): for expert_idx, expert_tensors in enumerate(tensors): expert_key = ( f"{key}.expert_{expert_idx}" - if self.config.model_type != "mixtral" and self.config.model_type != "grok-1" and self.config.model_type != "arctic" + if self.config.model_type != "mixtral" + and self.config.model_type != "grok-1" + and self.config.model_type != "arctic" + and self.config.model_type != "deepseek_v2" else f"{key}.{expert_idx}" ) - input_device_index = self.archer_engine.get_node_default_device( - expert_tensors + input_device_index = ( + self.archer_engine.get_node_default_device( + expert_tensors + ) ) gen_args_hook( expert_key, @@ -849,7 +920,9 @@ def gen_args_hook(key, input_device_index, output_device_index, tensors): input_device_index = self.archer_engine.get_node_default_device( tensors[0] ) - gen_args_hook(key, input_device_index, output_device_index, tensors[0]) + gen_args_hook( + key, input_device_index, output_device_index, tensors[0] + ) output_device_index = input_device_index # @torch.no_grad() @@ -882,7 +955,9 @@ def _offload_state_dict( for param_name in param_names: self.name_id_map[param_name] = self._generate_param_id() - if not self.archer_engine.is_tensor_offloaded(self.name_id_map[param_name]): + if not self.archer_engine.is_tensor_offloaded( + self.name_id_map[param_name] + ): self.archer_engine.offload( state_dict[param_name], self.name_id_map[param_name] ) @@ -911,7 +986,7 @@ def _pre_forward_module_hook(module, args, kwargs): device_list = [] for name, param in module.named_parameters(recurse=False): - if not param.data.data_ptr() in self.offload_set: + if param.data.data_ptr() not in self.offload_set: num_devices = torch.cuda.device_count() param.data = param.data.to(f"cuda:{num_devices-1}") continue @@ -923,8 +998,7 @@ def _pre_forward_module_hook(module, args, kwargs): device_list.append(param.data.device) for name, buf in module.named_buffers(recurse=False): - - if not buf.data.data_ptr() in self.offload_set: + if buf.data.data_ptr() not in self.offload_set: buf.data = buf.data.to("cuda:0") continue @@ -942,8 +1016,7 @@ def _post_forward_module_hook(module, input, output): device_list = [] param_not_offload = set() for param in module.parameters(recurse=False): - - if not 
param.data.data_ptr() in self.offload_set: + if param.data.data_ptr() not in self.offload_set: param_not_offload.add(param.data.data_ptr()) continue @@ -954,8 +1027,7 @@ def _post_forward_module_hook(module, input, output): device_list.append(param.data.device) for buf in module.buffers(recurse=False): - - if not buf.data_ptr() in self.offload_set: + if buf.data_ptr() not in self.offload_set: continue self.offload_set.remove(buf.data_ptr()) @@ -972,7 +1044,9 @@ def _post_forward_module_hook(module, input, output): # Pre forward hook self.forward_hooks.append( - module.register_forward_pre_hook(_pre_forward_module_hook, with_kwargs=True) + module.register_forward_pre_hook( + _pre_forward_module_hook, with_kwargs=True + ) ) # Post forward hook @@ -983,18 +1057,16 @@ def _post_forward_module_hook(module, input, output): # clean runtime hooks def clean_up(self): transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._cast_classifier = transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router._old_cast_classifier - transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP = ( - transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp - ) - + transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersSparseMLP = transformers.models.switch_transformers.modeling_switch_transformers._old_sparse_mlp + transformers.models.nllb_moe.modeling_nllb_moe.NllbMoeSparseMLP = ( transformers.models.nllb_moe.modeling_nllb_moe._old_sparse_mlp ) - + transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock = ( transformers.models.mixtral.modeling_mixtral._old_sparse_mlp ) - + moe_infinity.models.modeling_grok.modeling_grok1.MoeBlock = ( moe_infinity.modeling_grok.modeling_grok1._old_sparse_mlp ) @@ -1002,5 +1074,5 @@ def clean_up(self): moe_infinity.models.modeling_arctic.modeling_arctic.ArcticMoE = ( moe_infinity.models.modeling_arctic._old_sparse_mlp ) - - \ No newline at end of file + + moe_infinity.models.modeling_deepseek.modeling_deepseek.DeepseekV2MoE = moe_infinity.models.modeling_deepseek._old_sparse_mlp diff --git a/moe_infinity/utils/__init__.py b/moe_infinity/utils/__init__.py index cdf1291..80c2c52 100644 --- a/moe_infinity/utils/__init__.py +++ b/moe_infinity/utils/__init__.py @@ -1,7 +1,7 @@ +from .checkpoints import get_checkpoint_paths +from .config import ArcherConfig from .hf_config import ( - parse_moe_param, - parse_expert_id, parse_expert_dtype, + parse_expert_id, + parse_moe_param, ) -from .config import ArcherConfig -from .checkpoints import get_checkpoint_paths diff --git a/moe_infinity/utils/arguments.py b/moe_infinity/utils/arguments.py index 0d31be2..0b54109 100644 --- a/moe_infinity/utils/arguments.py +++ b/moe_infinity/utils/arguments.py @@ -12,14 +12,14 @@ def copy_args_to_device(device, args): return args.to(device) for i in range(len(args)): if isinstance(args[i], torch.Tensor): - new_args += (args[i].to(device, non_blocking=True), ) + new_args += (args[i].to(device, non_blocking=True),) elif isinstance(args[i], list) or isinstance(args[i], tuple): # move_args_to_device(device, *args[i]) - new_args += (copy_args_to_device(device, args[i]), ) + new_args += (copy_args_to_device(device, args[i]),) elif isinstance(args[i], dict): - new_args += (copy_kwargs_to_device(device, args[i]), ) + new_args += (copy_kwargs_to_device(device, args[i]),) else: - new_args += (args[i], ) + new_args += (args[i],) # 
print("new_args", device, new_args) return new_args diff --git a/moe_infinity/utils/checkpoints.py b/moe_infinity/utils/checkpoints.py index 4c4fe04..8c16a9d 100644 --- a/moe_infinity/utils/checkpoints.py +++ b/moe_infinity/utils/checkpoints.py @@ -12,22 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import gc -import inspect import json -import logging import os -import re -import shutil -import tempfile -from collections import OrderedDict, defaultdict -from typing import Dict, List, Optional, Tuple, Union +from typing import Union -import torch -import torch.nn as nn - -from accelerate.utils.constants import WEIGHTS_NAME, SAFE_WEIGHTS_NAME +from accelerate.utils.constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME def get_checkpoint_paths(checkpoint: Union[str, os.PathLike]): @@ -47,15 +36,25 @@ def get_checkpoint_paths(checkpoint: Union[str, os.PathLike]): checkpoint_files = [checkpoint] elif os.path.isdir(checkpoint): # check if the whole state dict is present - potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME] - potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME] + potential_state_bin = [ + f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME + ] + potential_state_safetensor = [ + f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME + ] if len(potential_state_bin) == 1: - checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])] + checkpoint_files = [ + os.path.join(checkpoint, potential_state_bin[0]) + ] elif len(potential_state_safetensor) == 1: - checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])] + checkpoint_files = [ + os.path.join(checkpoint, potential_state_safetensor[0]) + ] else: # otherwise check for sharded checkpoints - potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")] + potential_index = [ + f for f in os.listdir(checkpoint) if f.endswith(".index.json") + ] if len(potential_index) == 0: raise ValueError( f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file" @@ -80,6 +79,8 @@ def get_checkpoint_paths(checkpoint: Union[str, os.PathLike]): if "weight_map" in index: index = index["weight_map"] checkpoint_files = sorted(list(set(index.values()))) - checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files] + checkpoint_files = [ + os.path.join(checkpoint_folder, f) for f in checkpoint_files + ] return checkpoint_files diff --git a/moe_infinity/utils/config.py b/moe_infinity/utils/config.py index c67a6ee..86c0515 100644 --- a/moe_infinity/utils/config.py +++ b/moe_infinity/utils/config.py @@ -3,10 +3,11 @@ # TorchMoE Team -from dataclasses import dataclass, field import os -from transformers import HfArgumentParser +from dataclasses import dataclass, field + import torch +from transformers import HfArgumentParser @dataclass @@ -54,7 +55,9 @@ def load_from_json(self, config_json): return self def __post_init__(self): - self.perfect_cache_file = os.path.join(self.offload_path, "perfect_cache") + self.perfect_cache_file = os.path.join( + self.offload_path, "perfect_cache" + ) self.device_per_node = ( torch.cuda.device_count() diff --git a/moe_infinity/utils/hf_config.py b/moe_infinity/utils/hf_config.py index e7ef09d..56d4c2c 100644 --- a/moe_infinity/utils/hf_config.py +++ b/moe_infinity/utils/hf_config.py @@ -1,9 +1,10 @@ -from transformers import PretrainedConfig -from 
typing import Tuple import re +from typing import Tuple + import torch from transformers import PretrainedConfig + def parse_expert_dtype(config: PretrainedConfig) -> int: dtype = config.torch_dtype if dtype == torch.bfloat16: @@ -17,6 +18,7 @@ def parse_expert_dtype(config: PretrainedConfig) -> int: return dtype + def parse_moe_param(config: PretrainedConfig) -> Tuple[int, int, int]: arch = config.architectures[0].lower() @@ -40,13 +42,20 @@ def parse_moe_param(config: PretrainedConfig) -> Tuple[int, int, int]: num_decoder_layers = config.num_hidden_layers num_layers = config.num_hidden_layers num_experts = config.num_experts + elif "deepseek" in arch: + num_encoder_layers = 0 + num_decoder_layers = config.num_hidden_layers + num_layers = config.num_hidden_layers + num_experts = config.n_routed_experts else: raise RuntimeError(f"Unsupported architecture {arch}") return num_layers, num_experts, num_encoder_layers -def parse_expert_id(param_name: str, config: PretrainedConfig) -> Tuple[int, int]: +def parse_expert_id( + param_name: str, config: PretrainedConfig +) -> Tuple[int, int]: arch = config.architectures[0].lower() _, _, num_encoder_layers = parse_moe_param(config) @@ -91,6 +100,18 @@ def parse_expert_id(param_name: str, config: PretrainedConfig) -> Tuple[int, int # print(f"layer_id: {layer_id}, expert_id: {expert_id}") layer_id = int(layer_id) expert_id = int(expert_id) + elif "deepseek" in arch: + encoder_sparse_step = None + decoder_sparse_step = 1 + layer_type = "decoder" + + # example "model.layers.1.mlp.experts.0.gate_proj.weight" + result = re.findall(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.", param_name) + if result: + layer_id, expert_id = result[0] + # print(f"layer_id: {layer_id}, expert_id: {expert_id}") + layer_id = int(layer_id) + expert_id = int(expert_id) if result: if layer_type == "decoder": diff --git a/op_builder/__init__.py b/op_builder/__init__.py index bc59ba4..0835a37 100644 --- a/op_builder/__init__.py +++ b/op_builder/__init__.py @@ -12,12 +12,12 @@ # MoE-Infinity: replced builder_closure with PrefetchBuilder -import sys +import importlib import os import pkgutil -import importlib +import sys -from .builder import get_default_compute_capabilities, OpBuilder +from .builder import OpBuilder, get_default_compute_capabilities from .prefetch import PrefetchBuilder # Do not remove, required for abstract accelerator to detect if we have a deepspeed or 3p op_builder @@ -26,6 +26,7 @@ # List of all available op builders from deepspeed op_builder try: import moe_infinity.ops.op_builder # noqa: F401 + op_builder_dir = "moe_infinity.ops.op_builder" except ImportError: op_builder_dir = "op_builder" @@ -39,15 +40,18 @@ def builder_closure(member_name): # reflect builder names and add builder closure, such as 'TransformerBuilder()' creates op builder wrt current accelerator for _, module_name, _ in pkgutil.iter_modules( - [os.path.dirname(this_module.__file__)]): - if module_name != 'all_ops' and module_name != 'builder': - module = importlib.import_module(f".{module_name}", - package=op_builder_dir) + [os.path.dirname(this_module.__file__)] +): + if module_name != "all_ops" and module_name != "builder": + module = importlib.import_module( + f".{module_name}", package=op_builder_dir + ) for member_name in module.__dir__(): - if member_name.endswith( - 'Builder' - ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder": + if ( + member_name.endswith("Builder") + and member_name != "OpBuilder" + and member_name != "CUDAOpBuilder" + ): # assign builder name to 
variable with same name # the following is equivalent to i.e. TransformerBuilder = "TransformerBuilder" - this_module.__dict__[member_name] = builder_closure( - member_name) + this_module.__dict__[member_name] = builder_closure(member_name) diff --git a/op_builder/all_ops.py b/op_builder/all_ops.py index e30a16f..bb494ec 100644 --- a/op_builder/all_ops.py +++ b/op_builder/all_ops.py @@ -12,8 +12,8 @@ # MoE-Infinity: deleted accelerator check. -import os import importlib +import os import pkgutil __op_builders__ = [] @@ -22,15 +22,19 @@ op_builder_module = importlib.import_module(op_builder_dir) for _, module_name, _ in pkgutil.iter_modules( - [os.path.dirname(op_builder_module.__file__)]): + [os.path.dirname(op_builder_module.__file__)] +): # avoid self references - if module_name != 'all_ops' and module_name != 'builder': - module = importlib.import_module("{}.{}".format( - op_builder_dir, module_name)) + if module_name != "all_ops" and module_name != "builder": + module = importlib.import_module( + "{}.{}".format(op_builder_dir, module_name) + ) for member_name in module.__dir__(): - if member_name.endswith( - 'Builder' - ) and member_name != "OpBuilder" and member_name != "CUDAOpBuilder": + if ( + member_name.endswith("Builder") + and member_name != "OpBuilder" + and member_name != "CUDAOpBuilder" + ): # append builder to __op_builders__ list builder = getattr(module, member_name)() __op_builders__.append(builder) diff --git a/op_builder/builder.py b/op_builder/builder.py index 0221f73..01f2262 100644 --- a/op_builder/builder.py +++ b/op_builder/builder.py @@ -10,23 +10,23 @@ # See https://github.com/microsoft/DeepSpeed/blob/master/LICENSE for license information. # SPDX-License-Identifier: Apache-2.0 +import distutils.ccompiler +import distutils.log +import distutils.sysconfig import os -import sys -import time -from pathlib import Path -import subprocess import shlex import shutil +import subprocess +import sys import tempfile -import distutils.ccompiler -import distutils.log -import distutils.sysconfig -from distutils.errors import CompileError, LinkError +import time from abc import ABC, abstractmethod +from distutils.errors import CompileError, LinkError +from pathlib import Path from typing import List -YELLOW = '\033[93m' -END = '\033[0m' +YELLOW = "\033[93m" +END = "\033[0m" WARNING = f"{YELLOW} [WARNING] {END}" DEFAULT_TORCH_EXTENSION_PATH = "/tmp/torch_extensions" @@ -39,20 +39,24 @@ f"{WARNING} unable to import torch, please install it if you want to pre-compile any deepspeed ops." 
) else: - TORCH_MAJOR = int(torch.__version__.split('.')[0]) - TORCH_MINOR = int(torch.__version__.split('.')[1]) + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) def installed_cuda_version(name=""): import torch.utils.cpp_extension + cuda_home = torch.utils.cpp_extension.CUDA_HOME - assert cuda_home is not None, "CUDA_HOME does not exist, unable to compile CUDA op(s)" + assert ( + cuda_home is not None + ), "CUDA_HOME does not exist, unable to compile CUDA op(s)" # Ensure there is not a cuda version mismatch between torch and nvcc compiler - output = subprocess.check_output([cuda_home + "/bin/nvcc", "-V"], - universal_newlines=True) + output = subprocess.check_output( + [cuda_home + "/bin/nvcc", "-V"], universal_newlines=True + ) output_split = output.split() release_idx = output_split.index("release") - release = output_split[release_idx + 1].replace(',', '').split(".") + release = output_split[release_idx + 1].replace(",", "").split(".") # Ignore patch versions, only look at major + minor cuda_major, cuda_minor = release[:2] return int(cuda_major), int(cuda_minor) @@ -61,10 +65,15 @@ def installed_cuda_version(name=""): def get_default_compute_capabilities(): compute_caps = DEFAULT_COMPUTE_CAPABILITIES import torch.utils.cpp_extension - if torch.utils.cpp_extension.CUDA_HOME is not None and installed_cuda_version( - )[0] >= 11: - if installed_cuda_version()[0] == 11 and installed_cuda_version( - )[1] == 0: + + if ( + torch.utils.cpp_extension.CUDA_HOME is not None + and installed_cuda_version()[0] >= 11 + ): + if ( + installed_cuda_version()[0] == 11 + and installed_cuda_version()[1] == 0 + ): # Special treatment of CUDA 11.0 because compute_86 is not supported. compute_caps += ";8.0" else: @@ -80,21 +89,32 @@ def get_default_compute_capabilities(): "10.1", "10.2", ], - 11: - ["11.0", "11.1", "11.2", "11.3", "11.4", "11.5", "11.6", "11.7", "11.8"], + 11: [ + "11.0", + "11.1", + "11.2", + "11.3", + "11.4", + "11.5", + "11.6", + "11.7", + "11.8", + ], 12: ["12.0", "12.1"], } def assert_no_cuda_mismatch(name=""): cuda_major, cuda_minor = installed_cuda_version(name) - sys_cuda_version = f'{cuda_major}.{cuda_minor}' - torch_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) + sys_cuda_version = f"{cuda_major}.{cuda_minor}" + torch_cuda_version = ".".join(torch.version.cuda.split(".")[:2]) # This is a show-stopping error, should probably not proceed past this if sys_cuda_version != torch_cuda_version: - if (cuda_major in cuda_minor_mismatch_ok - and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major] - and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major]): + if ( + cuda_major in cuda_minor_mismatch_ok + and sys_cuda_version in cuda_minor_mismatch_ok[cuda_major] + and torch_cuda_version in cuda_minor_mismatch_ok[cuda_major] + ): print( f"Installed CUDA version {sys_cuda_version} does not match the " f"version torch was compiled with {torch.version.cuda} " @@ -111,7 +131,8 @@ def assert_no_cuda_mismatch(name=""): raise Exception( f">- DeepSpeed Op Builder: Installed CUDA version {sys_cuda_version} does not match the " f"version torch was compiled with {torch.version.cuda}, unable to compile " - "cuda/cpp extensions without a matching cuda version.") + "cuda/cpp extensions without a matching cuda version." 
+ ) return True @@ -127,17 +148,17 @@ def __init__(self, name): @abstractmethod def absolute_name(self): - ''' + """ Returns absolute build path for cases where the op is pre-installed, e.g., deepspeed.ops.adam.cpu_adam will be installed as something like: deepspeed/ops/adam/cpu_adam.so - ''' + """ pass @abstractmethod def sources(self): - ''' + """ Returns list of source files for your op, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) - ''' + """ pass def hipify_extension(self): @@ -145,38 +166,41 @@ def hipify_extension(self): @staticmethod def validate_torch_version(torch_info): - install_torch_version = torch_info['version'] - current_torch_version = ".".join(torch.__version__.split('.')[:2]) + install_torch_version = torch_info["version"] + current_torch_version = ".".join(torch.__version__.split(".")[:2]) if install_torch_version != current_torch_version: raise RuntimeError( "PyTorch version mismatch! DeepSpeed ops were compiled and installed " "with a different version than what is being used at runtime. " f"Please re-install DeepSpeed or switch torch versions. " f"Install torch version={install_torch_version}, " - f"Runtime torch version={current_torch_version}") + f"Runtime torch version={current_torch_version}" + ) @staticmethod def validate_torch_op_version(torch_info): if not OpBuilder.is_rocm_pytorch(): - current_cuda_version = ".".join(torch.version.cuda.split('.')[:2]) - install_cuda_version = torch_info['cuda_version'] + current_cuda_version = ".".join(torch.version.cuda.split(".")[:2]) + install_cuda_version = torch_info["cuda_version"] if install_cuda_version != current_cuda_version: raise RuntimeError( "CUDA version mismatch! DeepSpeed ops were compiled and installed " "with a different version than what is being used at runtime. " f"Please re-install DeepSpeed or switch torch versions. " f"Install CUDA version={install_cuda_version}, " - f"Runtime CUDA version={current_cuda_version}") + f"Runtime CUDA version={current_cuda_version}" + ) else: - current_hip_version = ".".join(torch.version.hip.split('.')[:2]) - install_hip_version = torch_info['hip_version'] + current_hip_version = ".".join(torch.version.hip.split(".")[:2]) + install_hip_version = torch_info["hip_version"] if install_hip_version != current_hip_version: raise RuntimeError( "HIP version mismatch! DeepSpeed ops were compiled and installed " "with a different version than what is being used at runtime. " f"Please re-install DeepSpeed or switch torch versions. 
" f"Install HIP version={install_hip_version}, " - f"Runtime HIP version={current_hip_version}") + f"Runtime HIP version={current_hip_version}" + ) @staticmethod def is_rocm_pytorch(): @@ -190,10 +214,13 @@ def is_rocm_pytorch(): pass else: if TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 5): - _is_rocm_pytorch = hasattr( - torch.version, 'hip') and torch.version.hip is not None + _is_rocm_pytorch = ( + hasattr(torch.version, "hip") + and torch.version.hip is not None + ) if _is_rocm_pytorch: from torch.utils.cpp_extension import ROCM_HOME + _is_rocm_pytorch = ROCM_HOME is not None OpBuilder._is_rocm_pytorch = _is_rocm_pytorch return OpBuilder._is_rocm_pytorch @@ -203,46 +230,47 @@ def installed_rocm_version(): if OpBuilder._rocm_version: return OpBuilder._rocm_version - ROCM_MAJOR = '0' - ROCM_MINOR = '0' + ROCM_MAJOR = "0" + ROCM_MINOR = "0" if OpBuilder.is_rocm_pytorch(): from torch.utils.cpp_extension import ROCM_HOME + rocm_ver_file = Path(ROCM_HOME).joinpath(".info/version-dev") if rocm_ver_file.is_file(): - with open(rocm_ver_file, 'r') as file: + with open(rocm_ver_file, "r") as file: ROCM_VERSION_DEV_RAW = file.read() elif "rocm" in torch.__version__: ROCM_VERSION_DEV_RAW = torch.__version__.split("rocm")[1] else: assert False, "Could not detect ROCm version" assert ROCM_VERSION_DEV_RAW != "", "Could not detect ROCm version" - ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split('.')[0] - ROCM_MINOR = ROCM_VERSION_DEV_RAW.split('.')[1] + ROCM_MAJOR = ROCM_VERSION_DEV_RAW.split(".")[0] + ROCM_MINOR = ROCM_VERSION_DEV_RAW.split(".")[1] OpBuilder._rocm_version = (int(ROCM_MAJOR), int(ROCM_MINOR)) return OpBuilder._rocm_version def include_paths(self): - ''' + """ Returns list of include paths, relative to root of deepspeed package (i.e., DeepSpeed/deepspeed) - ''' + """ return [] def nvcc_args(self): - ''' + """ Returns optional list of compiler flags to forward to nvcc when building CUDA sources - ''' + """ return [] def cxx_args(self): - ''' + """ Returns optional list of compiler flags to forward to the build - ''' + """ return [] def is_compatible(self, verbose=True): - ''' + """ Check if all non-python dependencies are satisfied to build this op - ''' + """ return True def extra_ldflags(self): @@ -250,24 +278,26 @@ def extra_ldflags(self): def libraries_installed(self, libraries): valid = False - check_cmd = 'dpkg -l' + check_cmd = "dpkg -l" for lib in libraries: - result = subprocess.Popen(f'dpkg -l {lib}', - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - shell=True) + result = subprocess.Popen( + f"dpkg -l {lib}", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + ) valid = valid or result.wait() == 0 return valid def has_function(self, funcname, libraries, verbose=False): - ''' + """ Test for existence of a function within a tuple of libraries. This is used as a smoke test to check whether a certain library is available. As a test, this creates a simple C program that calls the specified function, and then distutils is used to compile that program and link it with the specified libraries. Returns True if both the compile and link are successful, False otherwise. 
- ''' + """ tempdir = None # we create a temporary directory to hold various files filestderr = None # handle to open file to which we redirect stderr oldstderr = None # file descriptor for stderr @@ -286,42 +316,49 @@ def has_function(self, funcname, libraries, verbose=False): tempdir = tempfile.mkdtemp() # Define a simple C program that calls the function in question - prog = "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" % ( - funcname, funcname) + prog = ( + "void %s(void); int main(int argc, char** argv) { %s(); return 0; }" + % (funcname, funcname) + ) # Write the test program to a file. - filename = os.path.join(tempdir, 'test.c') - with open(filename, 'w') as f: + filename = os.path.join(tempdir, "test.c") + with open(filename, "w") as f: f.write(prog) # Redirect stderr file descriptor to a file to silence compile/link warnings. if not verbose: - filestderr = open(os.path.join(tempdir, 'stderr.txt'), 'w') + filestderr = open(os.path.join(tempdir, "stderr.txt"), "w") oldstderr = os.dup(sys.stderr.fileno()) os.dup2(filestderr.fileno(), sys.stderr.fileno()) # Workaround for behavior in distutils.ccompiler.CCompiler.object_filenames() # Otherwise, a local directory will be used instead of tempdir drive, driveless_filename = os.path.splitdrive(filename) - root_dir = driveless_filename[0] if os.path.isabs( - driveless_filename) else '' + root_dir = ( + driveless_filename[0] + if os.path.isabs(driveless_filename) + else "" + ) output_dir = os.path.join(drive, root_dir) # Attempt to compile the C program into an object file. - cflags = shlex.split(os.environ.get('CFLAGS', "")) + cflags = shlex.split(os.environ.get("CFLAGS", "")) objs = compiler.compile( [filename], output_dir=output_dir, - extra_preargs=self.strip_empty_entries(cflags)) + extra_preargs=self.strip_empty_entries(cflags), + ) # Attempt to link the object file into an executable. # Be sure to tack on any libraries that have been specified. - ldflags = shlex.split(os.environ.get('LDFLAGS', "")) + ldflags = shlex.split(os.environ.get("LDFLAGS", "")) compiler.link_executable( objs, - os.path.join(tempdir, 'a.out'), + os.path.join(tempdir, "a.out"), extra_preargs=self.strip_empty_entries(ldflags), - libraries=libraries) + libraries=libraries, + ) # Compile and link succeeded return True @@ -347,15 +384,15 @@ def has_function(self, funcname, libraries, verbose=False): shutil.rmtree(tempdir) def strip_empty_entries(self, args): - ''' + """ Drop any empty strings from the list of compile and link flags - ''' + """ return [x for x in args if len(x) > 0] def cpu_arch(self): try: from cpuinfo import get_cpu_info - except ImportError as e: + except ImportError: cpu_info = self._backup_cpuinfo() if cpu_info is None: return "-march=native" @@ -365,90 +402,93 @@ def cpu_arch(self): except Exception as e: self.warning( f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " - "falling back to `lscpu` to get this information.") + "falling back to `lscpu` to get this information." 
+ ) cpu_info = self._backup_cpuinfo() if cpu_info is None: return "-march=native" - if cpu_info['arch'].startswith('PPC_'): + if cpu_info["arch"].startswith("PPC_"): # gcc does not provide -march on PowerPC, use -mcpu instead - return '-mcpu=native' - return '-march=native' + return "-mcpu=native" + return "-march=native" def is_cuda_enable(self): try: assert_no_cuda_mismatch(self.name) - return '-D__ENABLE_CUDA__' + return "-D__ENABLE_CUDA__" except BaseException: print( f"{WARNING} {self.name} cuda is missing or is incompatible with installed torch, " - "only cpu ops can be compiled!") - return '-D__DISABLE_CUDA__' - return '-D__DISABLE_CUDA__' + "only cpu ops can be compiled!" + ) + return "-D__DISABLE_CUDA__" + return "-D__DISABLE_CUDA__" def _backup_cpuinfo(self): # Construct cpu_info dict from lscpu that is similar to what py-cpuinfo provides - if not self.command_exists('lscpu'): + if not self.command_exists("lscpu"): self.warning( f"{self.name} attempted to query 'lscpu' after failing to use py-cpuinfo " "to detect the CPU architecture. 'lscpu' does not appear to exist on " "your system, will fall back to use -march=native and non-vectorized execution." ) return None - result = subprocess.check_output('lscpu', shell=True) - result = result.decode('utf-8').strip().lower() + result = subprocess.check_output("lscpu", shell=True) + result = result.decode("utf-8").strip().lower() cpu_info = {} - cpu_info['arch'] = None - cpu_info['flags'] = "" - if 'genuineintel' in result or 'authenticamd' in result: - cpu_info['arch'] = 'X86_64' - if 'avx512' in result: - cpu_info['flags'] += 'avx512,' - elif 'avx512f' in result: - cpu_info['flags'] += 'avx512f,' - if 'avx2' in result: - cpu_info['flags'] += 'avx2' - elif 'ppc64le' in result: - cpu_info['arch'] = "PPC_" + cpu_info["arch"] = None + cpu_info["flags"] = "" + if "genuineintel" in result or "authenticamd" in result: + cpu_info["arch"] = "X86_64" + if "avx512" in result: + cpu_info["flags"] += "avx512," + elif "avx512f" in result: + cpu_info["flags"] += "avx512f," + if "avx2" in result: + cpu_info["flags"] += "avx2" + elif "ppc64le" in result: + cpu_info["arch"] = "PPC_" return cpu_info def simd_width(self): try: from cpuinfo import get_cpu_info - except ImportError as e: + except ImportError: cpu_info = self._backup_cpuinfo() if cpu_info is None: - return '-D__SCALAR__' + return "-D__SCALAR__" try: cpu_info = get_cpu_info() except Exception as e: self.warning( f"{self.name} attempted to use `py-cpuinfo` but failed (exception type: {type(e)}, {e}), " - "falling back to `lscpu` to get this information.") + "falling back to `lscpu` to get this information." 
+ ) cpu_info = self._backup_cpuinfo() if cpu_info is None: - return '-D__SCALAR__' + return "-D__SCALAR__" - if cpu_info['arch'] == 'X86_64': - if 'avx512' in cpu_info['flags'] or 'avx512f' in cpu_info['flags']: - return '-D__AVX512__' - elif 'avx2' in cpu_info['flags']: - return '-D__AVX256__' - return '-D__SCALAR__' + if cpu_info["arch"] == "X86_64": + if "avx512" in cpu_info["flags"] or "avx512f" in cpu_info["flags"]: + return "-D__AVX512__" + elif "avx2" in cpu_info["flags"]: + return "-D__AVX256__" + return "-D__SCALAR__" def command_exists(self, cmd): - if '|' in cmd: + if "|" in cmd: cmds = cmd.split("|") else: cmds = [cmd] valid = False for cmd in cmds: - result = subprocess.Popen(f'type {cmd}', - stdout=subprocess.PIPE, - shell=True) + result = subprocess.Popen( + f"type {cmd}", stdout=subprocess.PIPE, shell=True + ) valid = valid or result.wait() == 0 if not valid and len(cmds) > 1: @@ -470,10 +510,12 @@ def deepspeed_src_path(self, code_path): return code_path else: return os.path.join( - Path(__file__).parent.parent.absolute(), code_path) + Path(__file__).parent.parent.absolute(), code_path + ) def builder(self): from torch.utils.cpp_extension import CppExtension + abs_include_paths = [ self.deepspeed_src_path(path) for path in self.include_paths() ] @@ -482,9 +524,10 @@ def builder(self): sources=self.strip_empty_entries(self.sources()), include_dirs=self.strip_empty_entries(abs_include_paths), extra_compile_args={ - 'cxx': self.strip_empty_entries(self.cxx_args()) + "cxx": self.strip_empty_entries(self.cxx_args()) }, - extra_link_args=self.strip_empty_entries(self.extra_ldflags())) + extra_link_args=self.strip_empty_entries(self.extra_ldflags()), + ) def load(self, verbose=True): return self.jit_load(verbose) @@ -541,7 +584,8 @@ def jit_load(self, verbose=True): extra_cflags=cxx_args, extra_cuda_cflags=nvcc_args, extra_ldflags=self.strip_empty_entries(self.extra_ldflags()), - verbose=verbose) + verbose=verbose, + ) build_duration = time.time() - start_build if verbose: @@ -555,7 +599,6 @@ def jit_load(self, verbose=True): class CUDAOpBuilder(OpBuilder): - def compute_capability_args(self, cross_compile_archs=None): """ Returns nvcc compute capability compile flags. 
@@ -583,22 +626,23 @@ def compute_capability_args(self, cross_compile_archs=None): if cc not in ccs: ccs.append(cc) ccs = sorted(ccs) - ccs[-1] += '+PTX' + ccs[-1] += "+PTX" else: # Cross-compile mode, compile for various architectures # env override takes priority - cross_compile_archs_env = os.environ.get('TORCH_CUDA_ARCH_LIST', - None) + cross_compile_archs_env = os.environ.get( + "TORCH_CUDA_ARCH_LIST", None + ) if cross_compile_archs_env is not None: if cross_compile_archs is not None: print( f"{WARNING} env var `TORCH_CUDA_ARCH_LIST={cross_compile_archs_env}` overrides `cross_compile_archs={cross_compile_archs}`" ) - cross_compile_archs = cross_compile_archs_env.replace(' ', ';') + cross_compile_archs = cross_compile_archs_env.replace(" ", ";") else: if cross_compile_archs is None: cross_compile_archs = get_default_compute_capabilities() - ccs = cross_compile_archs.split(';') + ccs = cross_compile_archs.split(";") ccs = self.filter_ccs(ccs) if len(ccs) == 0: @@ -610,9 +654,9 @@ def compute_capability_args(self, cross_compile_archs=None): self.enable_bf16 = True for cc in ccs: num = cc[0] + cc[2] - args.append(f'-gencode=arch=compute_{num},code=sm_{num}') - if cc.endswith('+PTX'): - args.append(f'-gencode=arch=compute_{num},code=compute_{num}') + args.append(f"-gencode=arch=compute_{num},code=sm_{num}") + if cc.endswith("+PTX"): + args.append(f"-gencode=arch=compute_{num},code=compute_{num}") if int(cc[0]) <= 7: self.enable_bf16 = False @@ -630,13 +674,13 @@ def version_dependent_macros(self): # Fix from apex that might be relevant for us as well, related to https://github.com/NVIDIA/apex/issues/456 version_ge_1_1 = [] if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 0): - version_ge_1_1 = ['-DVERSION_GE_1_1'] + version_ge_1_1 = ["-DVERSION_GE_1_1"] version_ge_1_3 = [] if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 2): - version_ge_1_3 = ['-DVERSION_GE_1_3'] + version_ge_1_3 = ["-DVERSION_GE_1_3"] version_ge_1_5 = [] if (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4): - version_ge_1_5 = ['-DVERSION_GE_1_5'] + version_ge_1_5 = ["-DVERSION_GE_1_5"] return version_ge_1_1 + version_ge_1_3 + version_ge_1_5 def is_compatible(self, verbose=True): @@ -650,23 +694,33 @@ def builder(self): self.build_for_cpu = True if self.build_for_cpu: - from torch.utils.cpp_extension import CppExtension as ExtensionBuilder + from torch.utils.cpp_extension import ( + CppExtension as ExtensionBuilder, + ) else: - from torch.utils.cpp_extension import CUDAExtension as ExtensionBuilder + from torch.utils.cpp_extension import ( + CUDAExtension as ExtensionBuilder, + ) - compile_args = {'cxx': self.strip_empty_entries(self.cxx_args())} if self.build_for_cpu else \ - {'cxx': self.strip_empty_entries(self.cxx_args()), \ - 'nvcc': self.strip_empty_entries(self.nvcc_args())} + compile_args = ( + {"cxx": self.strip_empty_entries(self.cxx_args())} + if self.build_for_cpu + else { + "cxx": self.strip_empty_entries(self.cxx_args()), + "nvcc": self.strip_empty_entries(self.nvcc_args()), + } + ) if not self.build_for_cpu and self.enable_bf16: - compile_args['cxx'].append("-DBF16_AVAILABLE") + compile_args["cxx"].append("-DBF16_AVAILABLE") cuda_ext = ExtensionBuilder( name=self.absolute_name(), sources=self.strip_empty_entries(self.sources()), include_dirs=self.strip_empty_entries(self.include_paths()), libraries=self.strip_empty_entries(self.libraries_args()), - extra_compile_args=compile_args) + extra_compile_args=compile_args, + ) if self.is_rocm_pytorch(): # hip converts paths to 
absolute, this converts back to relative @@ -684,11 +738,12 @@ def builder(self): def hipify_extension(self): if self.is_rocm_pytorch(): from torch.utils.hipify import hipify_python + hipify_python.hipify( project_directory=os.getcwd(), output_directory=os.getcwd(), header_include_dirs=self.include_paths(), - includes=[os.path.join(os.getcwd(), '*')], + includes=[os.path.join(os.getcwd(), "*")], extra_files=[os.path.abspath(s) for s in self.sources()], show_detailed=True, is_pytorch_extension=True, @@ -697,35 +752,40 @@ def hipify_extension(self): def cxx_args(self): if sys.platform == "win32": - return ['-O2'] + return ["-O2"] else: - return ['-O3', '-std=c++14', '-g', '-Wno-reorder'] + return ["-O3", "-std=c++14", "-g", "-Wno-reorder"] def nvcc_args(self): if self.build_for_cpu: return [] - args = ['-O3'] + args = ["-O3"] if self.is_rocm_pytorch(): ROCM_MAJOR, ROCM_MINOR = self.installed_rocm_version() args += [ - '-std=c++14', '-U__HIP_NO_HALF_OPERATORS__', - '-U__HIP_NO_HALF_CONVERSIONS__', - '-U__HIP_NO_HALF2_OPERATORS__', - '-DROCM_VERSION_MAJOR=%s' % ROCM_MAJOR, - '-DROCM_VERSION_MINOR=%s' % ROCM_MINOR + "-std=c++14", + "-U__HIP_NO_HALF_OPERATORS__", + "-U__HIP_NO_HALF_CONVERSIONS__", + "-U__HIP_NO_HALF2_OPERATORS__", + "-DROCM_VERSION_MAJOR=%s" % ROCM_MAJOR, + "-DROCM_VERSION_MINOR=%s" % ROCM_MINOR, ] else: cuda_major, _ = installed_cuda_version() args += [ - '-allow-unsupported-compiler' - if sys.platform == "win32" else '', '--use_fast_math', - '-std=c++17' if sys.platform == "win32" and cuda_major > 10 - else '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', - '-U__CUDA_NO_HALF_CONVERSIONS__', - '-U__CUDA_NO_HALF2_OPERATORS__' + "-allow-unsupported-compiler" + if sys.platform == "win32" + else "", + "--use_fast_math", + "-std=c++17" + if sys.platform == "win32" and cuda_major > 10 + else "-std=c++14", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", ] - if os.environ.get('DS_DEBUG_CUDA_BUILD', '0') == '1': - args.append('--ptxas-options=-v') + if os.environ.get("DS_DEBUG_CUDA_BUILD", "0") == "1": + args.append("--ptxas-options=-v") args += self.compute_capability_args() return args @@ -734,39 +794,41 @@ def libraries_args(self): return [] if sys.platform == "win32": - return ['cublas', 'curand'] + return ["cublas", "curand"] else: return [] class TorchCPUOpBuilder(CUDAOpBuilder): - def extra_ldflags(self): if self.build_for_cpu: - return ['-fopenmp'] + return ["-fopenmp"] if not self.is_rocm_pytorch(): - return ['-lcurand'] + return ["-lcurand"] return [] def cxx_args(self): import torch + args = [] if not self.build_for_cpu: if not self.is_rocm_pytorch(): - CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.CUDA_HOME, - "lib64") + CUDA_LIB64 = os.path.join( + torch.utils.cpp_extension.CUDA_HOME, "lib64" + ) else: - CUDA_LIB64 = os.path.join(torch.utils.cpp_extension.ROCM_HOME, - "lib") + CUDA_LIB64 = os.path.join( + torch.utils.cpp_extension.ROCM_HOME, "lib" + ) args += super().cxx_args() args += [ - f'-L{CUDA_LIB64}', - '-lcudart', - '-lcublas', - '-g', + f"-L{CUDA_LIB64}", + "-lcudart", + "-lcublas", + "-g", ] CPU_ARCH = self.cpu_arch() @@ -774,7 +836,7 @@ def cxx_args(self): CUDA_ENABLE = self.is_cuda_enable() args += [ CPU_ARCH, - '-fopenmp', + "-fopenmp", SIMD_WIDTH, CUDA_ENABLE, ] diff --git a/op_builder/prefetch.py b/op_builder/prefetch.py index f4d320b..403329e 100644 --- a/op_builder/prefetch.py +++ b/op_builder/prefetch.py @@ -13,9 +13,6 @@ # MoE-Infinity: replaced AsyncIOBuilder with PrefetchBuilder from .builder 
import OpBuilder -import distutils -import subprocess -import glob class PrefetchBuilder(OpBuilder): @@ -24,56 +21,56 @@ class PrefetchBuilder(OpBuilder): def __init__(self): super().__init__(name=self.NAME) - + def absolute_name(self): - return f'moe_infinity.ops.prefetch.{self.NAME}_op' + return f"moe_infinity.ops.prefetch.{self.NAME}_op" def sources(self): return [ - 'core/utils/archer_logger.cpp', - 'core/utils/cuda_utils.cpp', - 'core/model/model_topology.cpp', - 'core/prefetch/archer_prefetch_handle.cpp', - 'core/prefetch/task_scheduler.cpp', - 'core/prefetch/task_thread.cpp', - 'core/memory/memory_pool.cpp', - 'core/memory/stream_pool.cpp', - 'core/memory/host_caching_allocator.cpp', - 'core/python/py_archer_prefetch.cpp', - 'core/parallel/expert_dispatcher.cpp', - 'core/parallel/expert_module.cpp', - 'core/aio/archer_aio_thread.cpp', - 'core/aio/archer_prio_aio_handle.cpp', - 'core/aio/archer_aio_utils.cpp', - 'core/aio/archer_aio_threadpool.cpp', - 'core/aio/archer_tensor_handle.cpp', - 'core/aio/archer_tensor_index.cpp', + "core/utils/archer_logger.cpp", + "core/utils/cuda_utils.cpp", + "core/model/model_topology.cpp", + "core/prefetch/archer_prefetch_handle.cpp", + "core/prefetch/task_scheduler.cpp", + "core/prefetch/task_thread.cpp", + "core/memory/memory_pool.cpp", + "core/memory/stream_pool.cpp", + "core/memory/host_caching_allocator.cpp", + "core/python/py_archer_prefetch.cpp", + "core/parallel/expert_dispatcher.cpp", + "core/parallel/expert_module.cpp", + "core/aio/archer_aio_thread.cpp", + "core/aio/archer_prio_aio_handle.cpp", + "core/aio/archer_aio_utils.cpp", + "core/aio/archer_aio_threadpool.cpp", + "core/aio/archer_tensor_handle.cpp", + "core/aio/archer_tensor_index.cpp", ] def include_paths(self): - return ['core'] + return ["core"] def cxx_args(self): # -O0 for improved debugging, since performance is bound by I/O CPU_ARCH = self.cpu_arch() SIMD_WIDTH = self.simd_width() return [ - '-g', - '-Wall', - '-O2', - '-std=c++17', - '-shared', - '-fPIC', - '-Wno-reorder', + "-g", + "-Wall", + "-O2", + "-std=c++17", + "-shared", + "-fPIC", + "-Wno-reorder", CPU_ARCH, - '-fopenmp', + "-fopenmp", SIMD_WIDTH, - '-I/usr/local/cuda/include', - '-L/usr/local/cuda/lib64', - '-lcuda', - '-lcudart', - '-lcublas', - '-lpthread', + "-I/usr/local/cuda/include", + "-L/usr/local/cuda/lib64", + "-lcuda", + "-lcudart", + "-lcublas", + "-lpthread", ] def extra_ldflags(self): diff --git a/pyproject.toml b/pyproject.toml index 07b58a6..4e83b0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,32 @@ [build-system] requires = ["setuptools", "wheel", "torch"] -build-backend = "setuptools.build_meta" \ No newline at end of file +build-backend = "setuptools.build_meta" + + +[tool.ruff] +line-length = 80 +exclude = [] + +[tool.ruff.lint] +fixable = ["ALL"] +unfixable = [ + # star imports + "F405", + "F403", + # lambda expression assignment + "E731", + # Loop control variable not used within loop body + "B007", + # raise distinguish errors + "B904", + # f-string format + "UP032", +] +select = [ + # isort + "I", +] +ignore = [ + # Loop control variable not used within loop body + "B007" +] diff --git a/requirements.txt b/requirements.txt index 12cf8bb..3c127d5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,16 +1,17 @@ +accelerate +auto_gptq +chardet +datasets>=2.12.0 hjson ninja +optimum>=1.17.1 packaging>=20.0 +pre-commit py-cpuinfo -torch>=2.1.1 -transformers>=4.37.1, <4.40 -sentencepiece -pydantic==1.10.12 -datasets>=2.12.0 pyarrow==12.0.0 -accelerate -sphinx 
+pydantic==1.10.12 scipy -chardet -optimum>=1.17.1 -auto_gptq +sentencepiece +sphinx +torch>=2.1.1 +transformers>=4.37.1, <4.40 diff --git a/setup.cfg b/setup.cfg index 42efc26..c70b9f7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [options.data_files] -. = requirements.txt \ No newline at end of file +. = requirements.txt diff --git a/setup.py b/setup.py index 936da6e..464e08a 100644 --- a/setup.py +++ b/setup.py @@ -4,43 +4,49 @@ # TorchMoE Team import io -from typing import List -from setuptools import setup, find_packages import os import sys +from setuptools import find_packages, setup + torch_available = True try: import torch # noqa: F401 except ImportError: torch_available = False - print('[WARNING] Unable to import torch, pre-compiling ops will be disabled. ' \ - 'Please visit https://pytorch.org/ to see how to properly install torch on your system.') + print( + "[WARNING] Unable to import torch, pre-compiling ops will be disabled. " + "Please visit https://pytorch.org/ to see how to properly install torch on your system." + ) ROOT_DIR = os.path.dirname(__file__) sys.path.insert(0, ROOT_DIR) # sys.path.insert(0, os.path.join(ROOT_DIR, 'src')) -from op_builder.all_ops import ALL_OPS from torch.utils import cpp_extension -RED_START = '\033[31m' -RED_END = '\033[0m' +from op_builder.all_ops import ALL_OPS + +RED_START = "\033[31m" +RED_END = "\033[0m" ERROR = f"{RED_START} [ERROR] {RED_END}" def fetch_requirements(path): - with open(path, 'r') as fd: + with open(path, "r") as fd: return [r.strip() for r in fd.readlines()] + def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) + def abort(msg): print(f"{ERROR} {msg}") assert False, msg + def read_readme() -> str: """Read the README file if present.""" p = get_path("README.md") @@ -49,14 +55,15 @@ def read_readme() -> str: else: return "" -install_requires = fetch_requirements('requirements.txt') + +install_requires = fetch_requirements("requirements.txt") ext_modules = [] -BUILD_OP_DEFAULT = int(os.environ.get('BUILD_OPS', 0)) +BUILD_OP_DEFAULT = int(os.environ.get("BUILD_OPS", 0)) if BUILD_OP_DEFAULT: - assert torch_available, 'Unable to pre-compile ops without torch installed. Please install torch before attempting to pre-compile ops.' + assert torch_available, "Unable to pre-compile ops without torch installed. Please install torch before attempting to pre-compile ops." 
compatible_ops = dict.fromkeys(ALL_OPS.keys(), False) install_ops = dict.fromkeys(ALL_OPS.keys(), False) for op_name, builder in ALL_OPS.items(): @@ -68,36 +75,38 @@ def read_readme() -> str: ext_modules.append(builder.builder()) cmdclass = { - 'build_ext': cpp_extension.BuildExtension.with_options(use_ninja=True) + "build_ext": cpp_extension.BuildExtension.with_options(use_ninja=True) } print(f"find_packages: {find_packages()}") -# install all files in the package, rather than just the egg +# install all files in the package, rather than just the egg setup( - name='moe_infinity', - version=os.getenv('MOEINF_VERSION', '0.0.1'), - packages=find_packages(exclude=['op_builder', 'op_builder.*', 'moe_infinity.ops.core.*']), + name="moe_infinity", + version=os.getenv("MOEINF_VERSION", "0.0.1"), + packages=find_packages( + exclude=["op_builder", "op_builder.*", "moe_infinity.ops.core.*"] + ), package_data={ - 'moe_infinity.ops.prefetch': ['**/*.so'], - 'moe_infinity': ['ops/core/**'] + "moe_infinity.ops.prefetch": ["**/*.so"], + "moe_infinity": ["ops/core/**"], }, include_package_data=True, install_requires=install_requires, - author='TorchMoE Team', + author="TorchMoE Team", long_description=read_readme(), long_description_content_type="text/markdown", url="https://github.com/TorchMoE/MoE-Infinity", - project_urls={'Homepage': 'https://github.com/TorchMoE/MoE-Infinity'}, + project_urls={"Homepage": "https://github.com/TorchMoE/MoE-Infinity"}, classifiers=[ - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', - 'Programming Language :: Python :: 3.10', - 'Programming Language :: Python :: 3.11', + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", "License :: OSI Approved :: Apache Software License", "Topic :: Scientific/Engineering :: Artificial Intelligence", ], - license='Apache License 2.0', + license="Apache License 2.0", python_requires=">=3.8", ext_modules=ext_modules, cmdclass=cmdclass, From b074ee7ff6c39654cbda20bee69e5171c922ad5e Mon Sep 17 00:00:00 2001 From: xly Date: Mon, 20 Jan 2025 15:14:04 +0000 Subject: [PATCH 4/4] working deepspeed backend --- core/parallel/expert_dispatcher.cpp | 29 +- core/parallel/expert_module.cpp | 49 ++- core/parallel/expert_module.h | 11 + examples/interface_example.py | 3 + moe_infinity/common/constants.py | 2 +- moe_infinity/distributed/expert_prefetcher.py | 1 - moe_infinity/memory/expert_prefetcher.py | 3 +- moe_infinity/models/deepseek.py | 125 ++++--- .../models/modeling_deepseek/__init__.py | 9 +- .../configuration_deepseek.py | 44 +-- .../modeling_deepseek/modeling_deepseek.py | 319 +++++++++++++----- .../tokenization_deepseek_fast.py | 4 +- moe_infinity/runtime/model_offload.py | 5 +- 13 files changed, 435 insertions(+), 169 deletions(-) diff --git a/core/parallel/expert_dispatcher.cpp b/core/parallel/expert_dispatcher.cpp index 5b0ed04..c3673e7 100644 --- a/core/parallel/expert_dispatcher.cpp +++ b/core/parallel/expert_dispatcher.cpp @@ -83,6 +83,9 @@ ExpertDispatcher::ExpertDispatcher(int num_experts, int num_layers, int dtype, i case MIXTRAL_MOE_DENSE_ACT_DENSE: experts_[i][j]->module = new MixtralMoEDenseActDense(dtype); break; + case DEEPSEEK_MOE_DENSE_ACT_DENSE: + experts_[i][j]->module = new DeepSeekMoEDenseActDense(dtype); + break; default: ARCHER_LOG_FATAL("ExpertDispatcher::ExpertDispatcher: unknown expert type ", expert_type); @@ -107,11 +110,20 @@ void 
ExpertDispatcher::EnqueueExpert(int layer_idx, int expert_idx, int gpu_id, void ExpertDispatcher::Enqueue(const CallArgs& args) { std::lock_guard lock(input_mutex_); - int layer_idx = args.layer_idx; int expert_idx = args.expert_idx; auto expert_node = experts_[expert_idx][layer_idx]; + // if (!expert_node->node->mutex.try_lock()) { + // ARCHER_LOG_WARN("ExpertDispatcher::Enqueue: mutex try_lock failed (expert_idx ", + // expert_idx, + // " layer_idx ", + // layer_idx, + // "node ", + // expert_node->node->str(), + // ")"); + // return; + // } expert_node->node->mutex.lock(); expert_node->node->last_access_time = MCIROSECONDS_SINCE_EPOCH; @@ -273,10 +285,11 @@ void ExpertDispatcher::GPUFetchFunc(int gpu_id) case NLLB_MOE_DENSE_ACT_DENSE: case FSGPT_MOE_DENSE_ACT_DENSE: case MIXTRAL_MOE_DENSE_ACT_DENSE: + case DEEPSEEK_MOE_DENSE_ACT_DENSE: input = hidden_states_.index({token_indices}).to(expert_node->node->device); break; default: - ARCHER_LOG_FATAL("ExpertDispatcher::ExpertDispatcher: unknown expert type ", + ARCHER_LOG_FATAL("ExpertDispatcher::expert_type: unknown expert type ", expert_type); } @@ -364,8 +377,12 @@ void ExpertDispatcher::GPUExecFunc(int gpu_id) output = reinterpret_cast(expert_module) ->forward(args.hidden_states); break; + case DEEPSEEK_MOE_DENSE_ACT_DENSE: + output = reinterpret_cast(expert_module) + ->forward(args.hidden_states); + break; default: - ARCHER_LOG_FATAL("ExpertDispatcher::ExpertDispatcher: unknown expert type", + ARCHER_LOG_FATAL("ExpertDispatcher::GPUExecFunc: unknown expert type", expert_type); } @@ -395,8 +412,8 @@ void ExpertDispatcher::GPUExecFunc(int gpu_id) void ExpertDispatcher::OutputFunc(ExecArgs args, torch::Tensor output, int gpu_id) { - c10::cuda::CUDAStream stream = c10::cuda::getStreamFromExternal(out_streams_[gpu_id], gpu_id); - c10::cuda::CUDAStreamGuard guard(stream); + // c10::cuda::CUDAStream stream = c10::cuda::getStreamFromExternal(out_streams_[gpu_id], + // gpu_id); c10::cuda::CUDAStreamGuard guard(stream); auto output_device = (args.out_gpu_id < 0) ? 
CPU_DEVICE : CUDA_DEVICE(args.out_gpu_id); torch::Tensor output_tensor = output.to(output_device).to(args.out_dtype); @@ -430,7 +447,7 @@ void ExpertDispatcher::OutputFunc(ExecArgs args, torch::Tensor output, int gpu_i args.hit, ")"); } - stream.synchronize(); + // stream.synchronize(); pending_.fetch_sub(1); } diff --git a/core/parallel/expert_module.cpp b/core/parallel/expert_module.cpp index 254cc28..cc643eb 100644 --- a/core/parallel/expert_module.cpp +++ b/core/parallel/expert_module.cpp @@ -156,23 +156,52 @@ torch::Tensor MixtralMoEDenseActDense::forward(torch::Tensor hidden_states) current_hidden_states = self.w2(current_hidden_states) return current_hidden_states */ - int w1_nan = torch::sum(torch::isnan(w1)).item(); - int w2_nan = torch::sum(torch::isnan(w2)).item(); - int w3_nan = torch::sum(torch::isnan(w3)).item(); - int hidden_states_nan = torch::sum(torch::isnan(hidden_states)).item(); + // int w1_nan = torch::sum(torch::isnan(w1)).item(); + // int w2_nan = torch::sum(torch::isnan(w2)).item(); + // int w3_nan = torch::sum(torch::isnan(w3)).item(); + // int hidden_states_nan = torch::sum(torch::isnan(hidden_states)).item(); // std::cout << "MixtralMoEDenseActDense w1 " << w1_nan << " w2 " << w2_nan << " w3 " << w3_nan // << " hidden_states " << hidden_states_nan << std::endl; - assert(w1_nan == 0); - assert(w2_nan == 0); - assert(w3_nan == 0); - assert(hidden_states_nan == 0); + // assert(w1_nan == 0); + // assert(w2_nan == 0); + // assert(w3_nan == 0); + // assert(hidden_states_nan == 0); + + // auto gate_proj = torch::silu(torch::matmul(hidden_states, w1.transpose(0, 1))); + // auto up_proj = torch::matmul(hidden_states, w3.transpose(0, 1)); + // auto down_proj = torch::matmul(gate_proj * up_proj, w2.transpose(0, 1)); return torch::matmul(torch::silu(torch::matmul(hidden_states, w1.transpose(0, 1))) * torch::matmul(hidden_states, w3.transpose(0, 1)), w2.transpose(0, 1)); } +DeepSeekMoEDenseActDense::DeepSeekMoEDenseActDense(int dtype) +{ + auto tensor_dtype = dtype_to_torch(dtype); + auto options = torch::TensorOptions().dtype(tensor_dtype).device(torch::kCPU); + gate_proj = register_parameter("gate_proj", torch::zeros({1}, options)); + up_proj = register_parameter("up_proj", torch::zeros({1}, options)); + down_proj = register_parameter("down_proj", torch::zeros({1}, options)); +} + +void DeepSeekMoEDenseActDense::SetTensorsFromBlob(void* ptr, + const std::vector& tensor_ids, + const torch::Device& device) +{ + gate_proj = kTensorIndex->find(tensor_ids[0])->second.tensor; + up_proj = kTensorIndex->find(tensor_ids[1])->second.tensor; + down_proj = kTensorIndex->find(tensor_ids[2])->second.tensor; +} + +torch::Tensor DeepSeekMoEDenseActDense::forward(torch::Tensor hidden_states) +{ + return torch::matmul(torch::silu(torch::matmul(hidden_states, gate_proj.transpose(0, 1))) * + torch::matmul(hidden_states, up_proj.transpose(0, 1)), + down_proj.transpose(0, 1)); +} + void ExpertNode::SetTensorsFromBlob(const torch::Device& device) { int expert_type = this->expert_type; @@ -197,6 +226,10 @@ void ExpertNode::SetTensorsFromBlob(const torch::Device& device) reinterpret_cast(module)->SetTensorsFromBlob( node->device_memory_ptr, node->tensor_ids, device); break; + case DEEPSEEK_MOE_DENSE_ACT_DENSE: + reinterpret_cast(module)->SetTensorsFromBlob( + node->device_memory_ptr, node->tensor_ids, device); + break; default: assert(false); } } diff --git a/core/parallel/expert_module.h b/core/parallel/expert_module.h index e91d862..035b1b1 100644 --- a/core/parallel/expert_module.h +++ 
b/core/parallel/expert_module.h @@ -15,6 +15,7 @@ #define NLLB_MOE_DENSE_ACT_DENSE 2 #define FSGPT_MOE_DENSE_ACT_DENSE 3 #define MIXTRAL_MOE_DENSE_ACT_DENSE 4 +#define DEEPSEEK_MOE_DENSE_ACT_DENSE 5 #define DTYPE_BFLOAT16 0 #define DTYPE_FLOAT32 1 @@ -78,6 +79,16 @@ struct MixtralMoEDenseActDense : public torch::nn::Module, public ModuleUtils { const torch::Device& device) override; }; +struct DeepSeekMoEDenseActDense : public torch::nn::Module, public ModuleUtils { + DeepSeekMoEDenseActDense(int dtype); + torch::Tensor forward(torch::Tensor hidden_states); + torch::Tensor gate_proj, up_proj, down_proj; + + void SetTensorsFromBlob(void* ptr, + const std::vector& tensor_ids, + const torch::Device& device) override; +}; + struct ExpertNode { NodePtr node; torch::nn::Module* module; diff --git a/examples/interface_example.py b/examples/interface_example.py index fd59474..bd73aa6 100644 --- a/examples/interface_example.py +++ b/examples/interface_example.py @@ -97,3 +97,6 @@ do_sample=False, **custom_kwargs, ) + +# sudo CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=1,2 numactl --membind=0 /mnt/data/xly/.conda/envs/moe-infinity/bin/python interface_example.py --model_name_or_path mistralai/Mixtral-8x7B-Instruct-v0.1 --offload_dir /mnt/raid0nvme1/xly/test-data +# sudo CUDA_LAUNCH_BLOCKING=1 CUDA_VISIBLE_DEVICES=1,2 numactl --membind=0 gdb --args /mnt/data/xly/.conda/envs/moe-infinity/bin/python /home/xly/MoE-Infinity/examples/interface_example.py --model_name_or_path deepseek-ai/DeepSeek-V2-Lite --offload_dir /mnt/raid0nvme1/xly/test-data diff --git a/moe_infinity/common/constants.py b/moe_infinity/common/constants.py index f6792b9..a45c2a9 100644 --- a/moe_infinity/common/constants.py +++ b/moe_infinity/common/constants.py @@ -30,7 +30,7 @@ "mixtral": 4, "grok": 4, "arctic": 4, - "deepseek": 4, + "deepseek": 5, } diff --git a/moe_infinity/distributed/expert_prefetcher.py b/moe_infinity/distributed/expert_prefetcher.py index 0bec6ea..8c88b32 100644 --- a/moe_infinity/distributed/expert_prefetcher.py +++ b/moe_infinity/distributed/expert_prefetcher.py @@ -49,7 +49,6 @@ def prefetch_experts(self, layer_id, expert_matrix): expert_list, key=lambda x: x[1], reverse=True ) tensor_ids = [x[0] for x in ordered_expert_list] - device_list = self.device_map_manager.get_target_device(tensor_ids) if len(tensor_ids) > 0: diff --git a/moe_infinity/memory/expert_prefetcher.py b/moe_infinity/memory/expert_prefetcher.py index 02c5fc5..c3b237e 100644 --- a/moe_infinity/memory/expert_prefetcher.py +++ b/moe_infinity/memory/expert_prefetcher.py @@ -4,6 +4,7 @@ # TorchMoE Team +import numpy as np from transformers import PretrainedConfig from moe_infinity.utils import parse_moe_param @@ -37,7 +38,7 @@ def prefetch_experts(self, layer_id, expert_matrix): expert_list, key=lambda x: x[1], reverse=True ) tensor_ids = [x[0] for x in ordered_expert_list] - + assert len(np.unique(tensor_ids)) == len(tensor_ids) self.archer_engine.replace_cache_candidates(tensor_ids) for tensor_id in tensor_ids: gpu_id = self.archer_engine.get_node_default_device([tensor_id]) diff --git a/moe_infinity/models/deepseek.py b/moe_infinity/models/deepseek.py index 6eb5061..f2aab9f 100644 --- a/moe_infinity/models/deepseek.py +++ b/moe_infinity/models/deepseek.py @@ -1,7 +1,8 @@ from typing import Dict, Optional, Tuple + import torch -import torch.nn.functional as F import torch.nn as nn +import torch.nn.functional as F from .modeling_deepseek import DeepseekV2MLP, MoEGate @@ -26,12 +27,13 @@ def __init__(self, config): ) self.gate = MoEGate(config) 
if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts + intermediate_size = ( + config.moe_intermediate_size * config.n_shared_experts + ) self.shared_experts = DeepseekV2MLP( config=config, intermediate_size=intermediate_size ) - - + self.archer_tracer = None self.archer_engine = None self.expert_tensor_ids: Dict[int, int] = None @@ -42,49 +44,92 @@ def forward(self, hidden_states): topk_idx, topk_weight, aux_loss = self.gate(hidden_states) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + print("topk_idx", topk_idx.shape) + print("topk_weight", topk_weight.shape) + print(self.config.n_routed_experts, self.config.num_experts_per_tok) + cnts = topk_idx.new_zeros((topk_idx.shape[0], len(self.experts))) cnts.scatter_(1, topk_idx, 1) tokens_per_expert = cnts.sum(dim=0) idxs = topk_idx.view(-1).argsort() - sorted_tokens = hidden_states[idxs // topk_idx.shape[1]] - + # sorted_tokens = hidden_states[idxs // topk_idx.shape[1]] + tokens_per_expert = tokens_per_expert.cpu().numpy() - - batch_size, sequence_length, _ = orig_shape - router_mask = F.one_hot(topk_idx, num_classes=self.config.n_routed_experts) - - # print("router_mask", router_mask.shape) - - expert_index = topk_idx.reshape(batch_size, sequence_length, self.config.num_experts_per_tok) + + batch_size, sequence_length, hidden_dim = orig_shape + router_mask = F.one_hot( + topk_idx, num_classes=self.config.n_routed_experts + ) + routing_weights_mask = (topk_weight[:, :, None] * router_mask).permute( + 0, 2, 1 + ) + routing_weights_mask = torch.sum(routing_weights_mask, dim=-1) + router_mask = router_mask.permute(0, 2, 1) + + # use logical or to merge last dimension + for i in range(self.config.num_experts_per_tok): + router_mask[:, :, 0] = torch.logical_or( + router_mask[:, :, 0], router_mask[:, :, i] + ) + router_mask = router_mask[:, :, 0] + print("router_mask", router_mask.shape) + print("routing_weights_mask", routing_weights_mask.shape) + + expert_index = topk_idx.reshape( + batch_size, sequence_length, self.config.num_experts_per_tok + ) for i in range(batch_size): seq_id = self.seq_id_list[i] - expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id) - self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix) - - outputs = [] - start_idx = 0 - for i, num_tokens in enumerate(tokens_per_expert): - end_idx = start_idx + num_tokens - if num_tokens == 0: - continue - expert = self.experts[i] - tokens_for_this_expert = sorted_tokens[start_idx:end_idx] - expert_out = expert(tokens_for_this_expert) - outputs.append(expert_out.to(hidden_states.device)) - start_idx = end_idx - - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) - - new_x = torch.empty_like(outs) - new_x[idxs] = outs - y = ( - new_x.view(*topk_idx.shape, -1) - .type(topk_weight.dtype) - .mul_(topk_weight.unsqueeze(dim=-1)) - .sum(dim=1) - .type(new_x.dtype) + expert_matrix = self.expert_predictor.predict( + seq_id, expert_index[i], self.layer_id + ) + self.expert_prefetcher.prefetch_experts( + self.layer_id, expert_matrix + ) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + results = self.expert_executor.dispatch_local( + hidden_states, router_mask, self.layer_id ) + for output, _, idx, _ in results: + token_indices = router_mask[:, idx].bool() + final_hidden_states[token_indices, :] += ( + output.to(routing_weights_mask.device) 
+ * routing_weights_mask[token_indices, idx][:, None] + ) + + # outputs = [] + # start_idx = 0 + # for i, num_tokens in enumerate(tokens_per_expert): + # end_idx = start_idx + num_tokens + # if num_tokens == 0: + # continue + # expert = self.experts[i] + # tokens_for_this_expert = sorted_tokens[start_idx:end_idx] + # expert_out = expert(tokens_for_this_expert) + # outputs.append(expert_out.to(hidden_states.device)) + # start_idx = end_idx + + # outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + # new_x = torch.empty_like(outs) + # new_x[idxs] = outs + # y = ( + # new_x.view(*topk_idx.shape, -1) + # .type(topk_weight.dtype) + # .mul_(topk_weight.unsqueeze(dim=-1)) + # .sum(dim=1) + # .type(new_x.dtype) + # ) + final_hidden_states = final_hidden_states.view( + batch_size, sequence_length, hidden_dim + ) if self.config.n_shared_experts is not None: - y = y + self.shared_experts(identity) - return y \ No newline at end of file + final_hidden_states = final_hidden_states + self.shared_experts( + identity + ) + return final_hidden_states diff --git a/moe_infinity/models/modeling_deepseek/__init__.py b/moe_infinity/models/modeling_deepseek/__init__.py index d8e8499..2e6b972 100644 --- a/moe_infinity/models/modeling_deepseek/__init__.py +++ b/moe_infinity/models/modeling_deepseek/__init__.py @@ -1,3 +1,8 @@ from .configuration_deepseek import DeepseekV2Config -from .modeling_deepseek import DeepseekV2ForCausalLM, DeepseekV2MLP, MoEGate, DeepseekV2MoE -from .tokenization_deepseek_fast import DeepseekTokenizerFast \ No newline at end of file +from .modeling_deepseek import ( + DeepseekV2ForCausalLM, + DeepseekV2MLP, + DeepseekV2MoE, + MoEGate, +) +from .tokenization_deepseek_fast import DeepseekTokenizerFast diff --git a/moe_infinity/models/modeling_deepseek/configuration_deepseek.py b/moe_infinity/models/modeling_deepseek/configuration_deepseek.py index 82e0f5d..d3b5148 100644 --- a/moe_infinity/models/modeling_deepseek/configuration_deepseek.py +++ b/moe_infinity/models/modeling_deepseek/configuration_deepseek.py @@ -4,6 +4,8 @@ logger = logging.get_logger(__name__) DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} + + class DeepseekV2Config(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`DeepseekV2Model`]. 
It is used to instantiate an DeepSeek @@ -117,29 +119,29 @@ def __init__( vocab_size=102400, hidden_size=4096, intermediate_size=11008, - moe_intermediate_size = 1407, + moe_intermediate_size=1407, num_hidden_layers=30, num_attention_heads=32, num_key_value_heads=32, - n_shared_experts = None, - n_routed_experts = None, - ep_size = 1, - routed_scaling_factor = 1.0, - kv_lora_rank = 512, - q_lora_rank = 1536, - qk_rope_head_dim = 64, - v_head_dim = 128, - qk_nope_head_dim = 128, - topk_method = 'gready', - n_group = None, - topk_group = None, - num_experts_per_tok = None, - moe_layer_freq = 1, - first_k_dense_replace = 0, - norm_topk_prob = False, - scoring_func = 'softmax', - aux_loss_alpha = 0.001, - seq_aux = True, + n_shared_experts=None, + n_routed_experts=None, + ep_size=1, + routed_scaling_factor=1.0, + kv_lora_rank=512, + q_lora_rank=1536, + qk_rope_head_dim=64, + v_head_dim=128, + qk_nope_head_dim=128, + topk_method="gready", + n_group=None, + topk_group=None, + num_experts_per_tok=None, + moe_layer_freq=1, + first_k_dense_replace=0, + norm_topk_prob=False, + scoring_func="softmax", + aux_loss_alpha=0.001, + seq_aux=True, hidden_act="silu", max_position_embeddings=2048, initializer_range=0.02, @@ -203,4 +205,4 @@ def __init__( eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs, - ) \ No newline at end of file + ) diff --git a/moe_infinity/models/modeling_deepseek/modeling_deepseek.py b/moe_infinity/models/modeling_deepseek/modeling_deepseek.py index 847a458..5ea7c8d 100644 --- a/moe_infinity/models/modeling_deepseek/modeling_deepseek.py +++ b/moe_infinity/models/modeling_deepseek/modeling_deepseek.py @@ -1,4 +1,3 @@ -# coding=utf-8 # Copyright 2023 DeepSeek-AI and The HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -17,17 +16,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" PyTorch DeepSeek model.""" +"""PyTorch DeepSeek model.""" + import math import warnings from typing import List, Optional, Tuple, Union +import numpy as np import torch +import torch.distributed as dist import torch.nn.functional as F import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss - from transformers.activations import ACT2FN from transformers.cache_utils import Cache, DynamicCache from transformers.modeling_attn_mask_utils import ( @@ -54,13 +55,16 @@ replace_return_docstrings, ) from transformers.utils.import_utils import is_torch_fx_available + from .configuration_deepseek import DeepseekV2Config -import torch.distributed as dist -import numpy as np if is_flash_attn_2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func - from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + from flash_attn.bert_padding import ( # noqa + index_first_axis, + pad_input, + unpad_input, + ) # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. 
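For orientation, the router bookkeeping introduced in the moe_infinity/models/deepseek.py forward pass earlier in this patch reduces the per-token top-k choices to one boolean membership mask and one routing weight per expert before handing tokens to expert_executor.dispatch_local. Below is a minimal, self-contained sketch of that arithmetic only: the dispatcher and prefetcher are replaced by a loop over stand-in torch.nn.Linear experts, the batch and sequence dimensions are flattened, and all sizes are illustrative assumptions; the sum(dim=1) > 0 reduction stands in for the logical_or merge loop in the patch.

    import torch
    import torch.nn.functional as F

    # Illustrative sizes only; the real values come from the model config and the batch.
    n_tokens, hidden_dim, n_experts, top_k = 8, 16, 4, 2
    hidden_states = torch.randn(n_tokens, hidden_dim)
    topk_weight = torch.rand(n_tokens, top_k)                  # gate scores of the chosen experts
    topk_idx = torch.randint(0, n_experts, (n_tokens, top_k))  # chosen expert ids per token

    # One-hot over experts, then fold the top-k axis so each token carries a single
    # boolean membership flag and a single routing weight per expert.
    one_hot = F.one_hot(topk_idx, num_classes=n_experts)                      # [tokens, top_k, experts]
    routing_weights_mask = (topk_weight[:, :, None] * one_hot).sum(dim=1)     # [tokens, experts]
    router_mask = one_hot.sum(dim=1) > 0                                      # [tokens, experts], bool

    # Stand-in experts; in the patch this step is delegated to expert_executor.dispatch_local.
    dummy_experts = [torch.nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(n_experts)]

    final_hidden_states = torch.zeros_like(hidden_states)
    for expert_id, expert in enumerate(dummy_experts):
        token_indices = router_mask[:, expert_id]
        if token_indices.any():
            expert_out = expert(hidden_states[token_indices])
            # Accumulate each expert's output scaled by its merged routing weight.
            final_hidden_states[token_indices] += (
                expert_out * routing_weights_mask[token_indices, expert_id][:, None]
            )

The weighted accumulation into final_hidden_states is intended to match what the previous (now commented-out) token-sorting path computed with topk_weight, just expressed per expert rather than per sorted token block.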
@@ -69,7 +73,9 @@ if not is_torch_greater_or_equal_than_1_13: import torch.fx - _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + _prepare_4d_causal_attention_mask = torch.fx.wrap( + _prepare_4d_causal_attention_mask + ) logger = logging.get_logger(__name__) @@ -104,7 +110,9 @@ def forward(self, hidden_states): input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) return self.weight * hidden_states.to(input_dtype) @@ -112,14 +120,17 @@ def forward(self, hidden_states): class DeepseekV2RotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__( + self, dim, max_position_embeddings=2048, base=10000, device=None + ): super().__init__() self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / ( - self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + self.base + ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -140,13 +151,19 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq.to(t.device)) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer( + "cos_cached", emb.cos().to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin().to(dtype), persistent=False + ) def forward(self, x, seq_len=None): # x: [bs, num_attention_heads, seq_len, head_size] if self.max_seq_len_cached is None or seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + self._set_cos_sin_cache( + seq_len=seq_len, device=x.device, dtype=x.dtype + ) return ( self.cos_cached[:seq_len].to(dtype=x.dtype), @@ -179,8 +196,12 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer( + "cos_cached", emb.cos().to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin().to(dtype), persistent=False + ) # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->DeepseekV2 @@ -207,7 +228,8 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) inv_freq = 1.0 / ( - base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) + base + ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim) ) self.register_buffer("inv_freq", inv_freq, persistent=False) @@ -218,17 +240,21 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freqs = torch.outer(t, self.inv_freq) # Different from paper, but it uses a different permutation in order to obtain the same calculation emb = torch.cat((freqs, 
freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + self.register_buffer( + "cos_cached", emb.cos().to(dtype), persistent=False + ) + self.register_buffer( + "sin_cached", emb.sin().to(dtype), persistent=False + ) # Inverse dim formula to find dim based on number of rotations def yarn_find_correction_dim( num_rotations, dim, base=10000, max_position_embeddings=2048 ): - return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( - 2 * math.log(base) - ) + return ( + dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi)) + ) / (2 * math.log(base)) # Find dim range bounds based on rotations @@ -260,7 +286,6 @@ def yarn_linear_ramp_mask(min, max, dim): class DeepseekV2YarnRotaryEmbedding(DeepseekV2RotaryEmbedding): - def __init__( self, dim, @@ -288,12 +313,18 @@ def _set_cos_sin_cache(self, seq_len, device, dtype): freq_extra = 1.0 / ( self.base - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) + / dim + ) ) freq_inter = 1.0 / ( self.scaling_factor * self.base - ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim) + ** ( + torch.arange(0, dim, 2, dtype=torch.float32, device=device) + / dim + ) ) low, high = yarn_find_correction_range( @@ -375,18 +406,30 @@ class DeepseekV2MLP(nn.Module): def __init__(self, config, hidden_size=None, intermediate_size=None): super().__init__() self.config = config - self.hidden_size = config.hidden_size if hidden_size is None else hidden_size + self.hidden_size = ( + config.hidden_size if hidden_size is None else hidden_size + ) self.intermediate_size = ( - config.intermediate_size if intermediate_size is None else intermediate_size + config.intermediate_size + if intermediate_size is None + else intermediate_size ) - self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) - self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) + self.gate_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.up_proj = nn.Linear( + self.hidden_size, self.intermediate_size, bias=False + ) + self.down_proj = nn.Linear( + self.intermediate_size, self.hidden_size, bias=False + ) self.act_fn = ACT2FN[config.hidden_act] def forward(self, x): - down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + down_proj = self.down_proj( + self.act_fn(self.gate_proj(x)) * self.up_proj(x) + ) return down_proj @@ -422,7 +465,9 @@ def forward(self, hidden_states): ### compute gating score hidden_states = hidden_states.view(-1, h) logits = F.linear( - hidden_states.type(torch.float32), self.weight.type(torch.float32), None + hidden_states.type(torch.float32), + self.weight.type(torch.float32), + None, ) if self.scoring_func == "softmax": scores = logits.softmax(dim=-1, dtype=torch.float32) @@ -442,15 +487,15 @@ def forward(self, hidden_states): ) # [n, n_group] group_idx = torch.topk( group_scores, k=self.topk_group, dim=-1, sorted=False - )[ - 1 - ] # [n, top_k_group] + )[1] # [n, top_k_group] group_mask = torch.zeros_like(group_scores) # [n, n_group] group_mask.scatter_(1, group_idx, 1) # [n, n_group] score_mask = ( group_mask.unsqueeze(-1) .expand( - bsz * seq_len, self.n_group, self.n_routed_experts // self.n_group + bsz * seq_len, + self.n_group, 
+ self.n_routed_experts // self.n_group, ) .reshape(bsz * seq_len, -1) ) # [n, e] @@ -479,14 +524,17 @@ def forward(self, hidden_states): ce.scatter_add_( 1, topk_idx_for_aux_loss, - torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device), + torch.ones( + bsz, seq_len * aux_topk, device=hidden_states.device + ), ).div_(seq_len * aux_topk / self.n_routed_experts) aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum( dim=1 ).mean() * self.alpha else: mask_ce = F.one_hot( - topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts + topk_idx_for_aux_loss.view(-1), + num_classes=self.n_routed_experts, ) ce = mask_ce.float().mean(0) Pi = scores_for_aux.mean(0) @@ -514,7 +562,9 @@ def forward(ctx, x, loss): def backward(ctx, grad_output): grad_loss = None if ctx.required_aux_loss: - grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device) + grad_loss = torch.ones( + 1, dtype=ctx.dtype, device=grad_output.device + ) return grad_output, grad_loss @@ -537,7 +587,8 @@ def __init__(self, config): [ ( DeepseekV2MLP( - config, intermediate_size=config.moe_intermediate_size + config, + intermediate_size=config.moe_intermediate_size, ) if i >= self.ep_rank * self.experts_per_rank and i < (self.ep_rank + 1) * self.experts_per_rank @@ -560,7 +611,9 @@ def __init__(self, config): ) self.gate = MoEGate(config) if config.n_shared_experts is not None: - intermediate_size = config.moe_intermediate_size * config.n_shared_experts + intermediate_size = ( + config.moe_intermediate_size * config.n_shared_experts + ) self.shared_experts = DeepseekV2MLP( config=config, intermediate_size=intermediate_size ) @@ -577,12 +630,18 @@ def forward(self, hidden_states): ) y = torch.empty_like(hidden_states) for i, expert in enumerate(self.experts): - y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i]) - y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1) + y[flat_topk_idx == i] = expert( + hidden_states[flat_topk_idx == i] + ) + y = ( + y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1) + ).sum(dim=1) y = y.to(hidden_states.dtype).view(*orig_shape) y = AddAuxiliaryLoss.apply(y, aux_loss) else: - y = self.moe_infer(hidden_states, topk_idx, topk_weight).view(*orig_shape) + y = self.moe_infer(hidden_states, topk_idx, topk_weight).view( + *orig_shape + ) if self.config.n_shared_experts is not None: y = y + self.shared_experts(identity) return y @@ -596,7 +655,9 @@ def moe_infer(self, x, topk_ids, topk_weight): sorted_tokens = x[idxs // topk_ids.shape[1]] sorted_tokens_shape = sorted_tokens.shape if self.ep_size > 1: - tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1) + tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum( + dim=1 + ) tokens_per_expert_group = tokens_per_expert.new_empty( tokens_per_expert.shape[0] ) @@ -609,7 +670,8 @@ def moe_infer(self, x, topk_ids, topk_weight): .tolist() ) gathered_tokens = sorted_tokens.new_empty( - tokens_per_expert_group.sum(dim=0).cpu().item(), sorted_tokens.shape[1] + tokens_per_expert_group.sum(dim=0).cpu().item(), + sorted_tokens.shape[1], ) input_split_sizes = tokens_per_ep_rank.cpu().numpy().tolist() dist.all_to_all( @@ -619,7 +681,9 @@ def moe_infer(self, x, topk_ids, topk_weight): tokens_per_expert_post_gather = tokens_per_expert_group.view( self.ep_size, self.experts_per_rank ).sum(dim=0) - gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32) + gatherd_idxs = np.zeros( + shape=(gathered_tokens.shape[0],), dtype=np.int32 + ) s = 0 for i, k in 
enumerate(tokens_per_expert_group.cpu().numpy()): gatherd_idxs[s : s + k] = i % self.experts_per_rank @@ -641,7 +705,11 @@ def moe_infer(self, x, topk_ids, topk_weight): outputs.append(expert_out) start_idx = end_idx - outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0) + outs = ( + torch.cat(outputs, dim=0) + if len(outputs) + else sorted_tokens.new_empty(0) + ) if self.ep_size > 1: new_x = torch.empty_like(outs) new_x[gatherd_idxs] = outs @@ -676,14 +744,18 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand( batch, num_key_value_heads, n_rep, slen, head_dim ) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + return hidden_states.reshape( + batch, num_key_value_heads * n_rep, slen, head_dim + ) # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->DeepseekV2 class DeepseekV2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config: DeepseekV2Config, layer_idx: Optional[int] = None): + def __init__( + self, config: DeepseekV2Config, layer_idx: Optional[int] = None + ): super().__init__() self.config = config self.layer_idx = layer_idx @@ -812,7 +884,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]] + ]: if "padding_mask" in kwargs: warnings.warn( "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" @@ -835,7 +909,12 @@ def forward( k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .view( + bsz, + q_len, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) .transpose(1, 2) ) @@ -850,12 +929,16 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx + ) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states = k_pe.new_empty( + bsz, self.num_heads, q_len, self.q_head_dim + ) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe @@ -869,7 +952,8 @@ def forward( ) attn_weights = ( - torch.matmul(query_states, key_states.transpose(2, 3)) * self.softmax_scale + torch.matmul(query_states, key_states.transpose(2, 3)) + * self.softmax_scale ) if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): @@ -902,7 +986,9 @@ def forward( attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.v_head_dim) + attn_output = attn_output.reshape( + bsz, q_len, self.num_heads * self.v_head_dim + ) attn_output = self.o_proj(attn_output) @@ -924,9 +1010,11 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. 
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + self._flash_attn_uses_top_left_mask = ( + not is_flash_attn_greater_or_equal_2_10() + ) def forward( self, @@ -937,7 +1025,9 @@ def forward( output_attentions: bool = False, use_cache: bool = False, **kwargs, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + ) -> Tuple[ + torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]] + ]: # DeepseekV2FlashAttention2 attention does not support output_attentions if "padding_mask" in kwargs: warnings.warn( @@ -970,7 +1060,12 @@ def forward( k_pe = k_pe.view(bsz, q_len, 1, self.qk_rope_head_dim).transpose(1, 2) kv = ( self.kv_b_proj(self.kv_a_layernorm(compressed_kv)) - .view(bsz, q_len, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) + .view( + bsz, + q_len, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) .transpose(1, 2) ) @@ -981,12 +1076,16 @@ def forward( kv_seq_len = value_states.shape[-2] if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + kv_seq_len += past_key_value.get_usable_length( + kv_seq_len, self.layer_idx + ) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) q_pe, k_pe = apply_rotary_pos_emb(q_pe, k_pe, cos, sin, position_ids) - query_states = k_pe.new_empty(bsz, self.num_heads, q_len, self.q_head_dim) + query_states = k_pe.new_empty( + bsz, self.num_heads, q_len, self.q_head_dim + ) query_states[:, :, :, : self.qk_nope_head_dim] = q_nope query_states[:, :, :, self.qk_nope_head_dim :] = q_pe @@ -995,7 +1094,9 @@ def forward( key_states[:, :, :, self.qk_nope_head_dim :] = k_pe if self.q_head_dim != self.v_head_dim: - value_states = F.pad(value_states, [0, self.q_head_dim - self.v_head_dim]) + value_states = F.pad( + value_states, [0, self.q_head_dim - self.v_head_dim] + ) if past_key_value is not None: cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models @@ -1109,7 +1210,11 @@ def _flash_attention_forward( cu_seq_lens, max_seq_lens, ) = self._upad_input( - query_states, key_states, value_states, attention_mask, query_length + query_states, + key_states, + value_states, + attention_mask, + query_length, ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -1146,20 +1251,28 @@ def _flash_attention_forward( def _upad_input( self, query_layer, key_layer, value_layer, attention_mask, query_length ): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data( + attention_mask + ) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( - key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + key_layer.reshape( + batch_size * kv_seq_len, num_key_value_heads, head_dim + ), 
indices_k, ) value_layer = index_first_axis( - value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), + value_layer.reshape( + batch_size * kv_seq_len, num_key_value_heads, head_dim + ), indices_k, ) if query_length == kv_seq_len: query_layer = index_first_axis( - query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), + query_layer.reshape( + batch_size * kv_seq_len, self.num_heads, head_dim + ), indices_k, ) cu_seqlens_q = cu_seqlens_k @@ -1175,8 +1288,8 @@ def _upad_input( else: # The -q_len: slice assumes left padding. attention_mask = attention_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input( - query_layer, attention_mask + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = ( + unpad_input(query_layer, attention_mask) ) return ( @@ -1421,8 +1534,12 @@ def __init__(self, config: DeepseekV2Config): for layer_idx in range(config.num_hidden_layers) ] ) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self.norm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self._use_flash_attention_2 = ( + config._attn_implementation == "flash_attention_2" + ) + self.norm = DeepseekV2RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -1457,10 +1574,14 @@ def forward( if output_hidden_states is not None else self.config.output_hidden_states ) - use_cache = use_cache if use_cache is not None else self.config.use_cache + use_cache = ( + use_cache if use_cache is not None else self.config.use_cache + ) return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + return_dict + if return_dict is not None + else self.config.use_return_dict ) # retrieve input_ids and inputs_embeds @@ -1473,7 +1594,9 @@ def forward( elif inputs_embeds is not None: batch_size, seq_length = inputs_embeds.shape[:2] else: - raise ValueError("You have to specify either input_ids or inputs_embeds") + raise ValueError( + "You have to specify either input_ids or inputs_embeds" + ) if self.gradient_checkpointing and self.training: if use_cache: @@ -1486,11 +1609,19 @@ def forward( if use_cache: use_legacy_cache = not isinstance(past_key_values, Cache) if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + past_key_values = DynamicCache.from_legacy_cache( + past_key_values + ) + past_key_values_length = past_key_values.get_usable_length( + seq_length + ) if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device + device = ( + input_ids.device + if input_ids is not None + else inputs_embeds.device + ) position_ids = torch.arange( past_key_values_length, seq_length + past_key_values_length, @@ -1553,7 +1684,9 @@ def forward( hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache = layer_outputs[2 if output_attentions else 1] + next_decoder_cache = layer_outputs[ + 2 if output_attentions else 1 + ] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1574,7 +1707,12 @@ def forward( if not return_dict: return tuple( v - for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] + for v in [ + hidden_states, + next_cache, + all_hidden_states, + all_self_attns, + ] if v is not None ) return BaseModelOutputWithPast( @@ -1592,7 +1730,9 @@ def __init__(self, config): 
super().__init__(config) self.model = DeepseekV2Model(config) self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.lm_head = nn.Linear( + config.hidden_size, config.vocab_size, bias=False + ) # Initialize weights and apply final processing self.post_init() @@ -1668,7 +1808,9 @@ def forward( else self.config.output_hidden_states ) return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + return_dict + if return_dict is not None + else self.config.use_return_dict ) # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) @@ -1732,13 +1874,15 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusivelly passed as part of the cache (e.g. when passing input_embeds as + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as # input) if ( attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1] ): - input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + input_ids = input_ids[ + :, -(attention_mask.shape[1] - past_length) : + ] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard # input_ids based on the past_length. elif past_length < input_ids.shape[1]: @@ -1842,7 +1986,9 @@ def forward( `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ return_dict = ( - return_dict if return_dict is not None else self.config.use_return_dict + return_dict + if return_dict is not None + else self.config.use_return_dict ) transformer_outputs = self.model( @@ -1873,7 +2019,10 @@ def forward( else: if input_ids is not None: sequence_lengths = ( - torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + torch.eq(input_ids, self.config.pad_token_id) + .int() + .argmax(-1) + - 1 ).to(logits.device) else: sequence_lengths = -1 diff --git a/moe_infinity/models/modeling_deepseek/tokenization_deepseek_fast.py b/moe_infinity/models/modeling_deepseek/tokenization_deepseek_fast.py index d243771..05aa411 100644 --- a/moe_infinity/models/modeling_deepseek/tokenization_deepseek_fast.py +++ b/moe_infinity/models/modeling_deepseek/tokenization_deepseek_fast.py @@ -1,11 +1,9 @@ from typing import List, Optional, Union - from transformers.models.llama import LlamaTokenizerFast class DeepseekTokenizerFast(LlamaTokenizerFast): - def convert_ids_to_tokens( self, ids: Union[int, List[int]], skip_special_tokens: bool = False ) -> Union[str, List[str]]: @@ -32,7 +30,7 @@ def convert_ids_to_tokens( token = self._tokenizer.id_to_token(index) tokens.append(token if token is not None else "") return tokens - + def _convert_id_to_token(self, index: int) -> Optional[str]: token = self._tokenizer.id_to_token(int(index)) return token if token is not None else "" diff --git a/moe_infinity/runtime/model_offload.py b/moe_infinity/runtime/model_offload.py index 241b9be..7022a7e 100644 --- a/moe_infinity/runtime/model_offload.py +++ b/moe_infinity/runtime/model_offload.py @@ -406,7 +406,7 @@ def archer_from_pretrained(cls, *args, **kwargs): self.checkpoint, "name_id_map.json" ) - model_name = args[0] + self.model_name = model_name = args[0] # if "arctic" in model_name: # self.config = ArcticConfig.from_pretrained(*args, **kwargs) # else: @@ -883,6 +883,9 @@ def gen_args_hook( ) expert_layer_id = 0 + if 
"deepseek" in self.model_name: + expert_layer_id = self.config.first_k_dense_replace + output_device_index = None for key, tensors in topo: # print(key, tensors)