From 08ded213ba1e212b399be58d3ec3bbe135f3e793 Mon Sep 17 00:00:00 2001 From: Leyang Xue Date: Sun, 5 May 2024 16:06:23 +0100 Subject: [PATCH] Feature/expert parallel (#9) * add back expert parallel by id hash * add grok ep * fix mistral typo * accom cuda copy bug * sync after compute * fix:sync to make sure that input is ready --------- Co-authored-by: xly Co-authored-by: luzhan <513964121@qq.com> --- README.md | 4 +- core/aio/archer_prio_aio_handle.cpp | 2 +- core/model/model_topology.cpp | 4 +- core/parallel/expert_dispatcher.cpp | 4 + core/prefetch/archer_prefetch_handle.cpp | 4 +- core/trace/archer_tensor_tracer.cpp | 140 ---- core/trace/archer_tensor_tracer.h | 30 - core/trace/model_topology.cpp | 699 -------------------- core/trace/model_topology.h | 203 ------ core/utils/cuda_utils.cpp | 15 + core/utils/cuda_utils.h | 7 + examples/interface_example.py | 9 +- moe_infinity/distributed/expert_executor.py | 19 + moe_infinity/entrypoints/big_modeling.py | 3 +- moe_infinity/models/grok.py | 82 ++- moe_infinity/models/mixtral.py | 27 +- moe_infinity/models/nllb_moe.py | 22 +- moe_infinity/models/switch_transformers.py | 19 +- moe_infinity/runtime/model_offload.py | 9 +- requirements.txt | 4 +- 20 files changed, 155 insertions(+), 1151 deletions(-) delete mode 100644 core/trace/archer_tensor_tracer.cpp delete mode 100644 core/trace/archer_tensor_tracer.h delete mode 100644 core/trace/model_topology.cpp delete mode 100644 core/trace/model_topology.h diff --git a/README.md b/README.md index aff768a..b89569e 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ Note that: The open-sourced MoE-Infinity has been redesigned for making it Huggi Single GPU A5000 (24GB Memory), per-token-latency (seconds) for generation with a mixed dataset that includes [FLAN](https://huggingface.co/datasets/Muennighoff/flan), [BIG-Bench](https://huggingface.co/datasets/bigbench) and [MMLU](https://huggingface.co/datasets/lukaemon/mmlu) datasets. Lower per-token-latency is preferable. -| | switch-large-128 | NLLB-MoE-54B | Mixtral-7x8b | +| | switch-large-128 | NLLB-MoE-54B | Mixtral-8x7b | | :---: | :---: | :---: | :---: | | MoE-Infinity | *0.230* | *0.239* | *0.895* | | Accelerate | 1.043 | 3.071 | 6.633 | @@ -48,7 +48,7 @@ Lower per-token-latency is preferable. Single GPU A5000, throughput (token/s) for generation with batch size 32. Higher throughput is preferable. -| | switch-large-128 | NLLB-MoE-54B | Mixtral-7x8b | +| | switch-large-128 | NLLB-MoE-54B | Mixtral-8x7b | | :---: | :---: | :---: | :---: | | MoE-Infinity | *69.105* | *30.300* | *12.579* | | Accelerate | 5.788 | 4.344 | 1.245 | diff --git a/core/aio/archer_prio_aio_handle.cpp b/core/aio/archer_prio_aio_handle.cpp index 5e996f7..1d1d118 100644 --- a/core/aio/archer_prio_aio_handle.cpp +++ b/core/aio/archer_prio_aio_handle.cpp @@ -90,7 +90,7 @@ std::int64_t ArcherPrioAioHandle::Write(const std::string& filename, auto mem_type = IsDevicePointer(buffer) ? 
cudaMemcpyDeviceToHost : cudaMemcpyHostToHost; cudaHostAlloc(&write_buffer, num_bytes_aligned, cudaHostAllocDefault); - cudaMemcpy(write_buffer, buffer, num_bytes, mem_type); + CudaMemcpy(write_buffer, buffer, num_bytes, mem_type); auto callbacks = aio_context_.PrepIocbs(false, write_buffer, fd, kBlockSize, offset, num_bytes_aligned); auto io_request = std::make_shared(); diff --git a/core/model/model_topology.cpp b/core/model/model_topology.cpp index 69e32d7..0d0daf0 100644 --- a/core/model/model_topology.cpp +++ b/core/model/model_topology.cpp @@ -114,9 +114,9 @@ void Node::SetDevice(const torch::Device& target_device, auto start_time = MCIROSECONDS_SINCE_EPOCH; if (stream == nullptr) { - cudaMemcpy(device_memory_ptr, host_memory_ptr, byte_size, cudaMemcpyHostToDevice); + CudaMemcpy(device_memory_ptr, host_memory_ptr, byte_size, cudaMemcpyHostToDevice); } else { - cudaMemcpyAsync( + CudaMemcpyAsync( device_memory_ptr, host_memory_ptr, byte_size, cudaMemcpyHostToDevice, stream); cudaStreamSynchronize(stream); } diff --git a/core/parallel/expert_dispatcher.cpp b/core/parallel/expert_dispatcher.cpp index bf6c731..6e48677 100644 --- a/core/parallel/expert_dispatcher.cpp +++ b/core/parallel/expert_dispatcher.cpp @@ -332,6 +332,7 @@ void ExpertDispatcher::GPUExecFunc(int gpu_id) auto* expert_module = args.expert_node->module; int expert_type = expert_type_; + cudaStreamSynchronize(0); // make sure the input is ready try { switch (expert_type) { @@ -369,6 +370,8 @@ void ExpertDispatcher::GPUExecFunc(int gpu_id) ss << "]"; ARCHER_LOG_FATAL("ExpertDispatcher::GPUExecFunc", ss.str(), "expert_type", expert_type, e.what()); } + + stream.synchronize(); } (void)std::async(std::launch::async, @@ -414,6 +417,7 @@ void ExpertDispatcher::OutputFunc(ExecArgs args, torch::Tensor output, int gpu_i gpu_id, args.hit, ")"); } + stream.synchronize(); pending_.fetch_sub(1); } diff --git a/core/prefetch/archer_prefetch_handle.cpp b/core/prefetch/archer_prefetch_handle.cpp index 5960b10..5a0271c 100644 --- a/core/prefetch/archer_prefetch_handle.cpp +++ b/core/prefetch/archer_prefetch_handle.cpp @@ -11,7 +11,7 @@ #include "common/time.h" #include "memory/memory_pool.h" #include "task_scheduler.h" - +#include "utils/cuda_utils.h" #include "utils/archer_logger.h" ArcherPrefetchHandle::ArcherPrefetchHandle(const std::string& prefix, @@ -335,7 +335,7 @@ void ArcherPrefetchHandle::SetTensorDevice(torch::Tensor& tensor, torch::Device cudaSetDevice(device.index()); cudaMalloc(&device_ptr, byte_size); - cudaMemcpy(device_ptr, tensor.data_ptr(), byte_size, cudaMemcpyDeviceToDevice); + CudaMemcpy(device_ptr, tensor.data_ptr(), byte_size, cudaMemcpyDeviceToDevice); auto new_tensor = torch::from_blob( device_ptr, diff --git a/core/trace/archer_tensor_tracer.cpp b/core/trace/archer_tensor_tracer.cpp deleted file mode 100644 index dd0679b..0000000 --- a/core/trace/archer_tensor_tracer.cpp +++ /dev/null @@ -1,140 +0,0 @@ -// Copyright (c) TorchMoE. 
-// SPDX-License-Identifier: Apache-2.0 - -// TorchMoE Team - -#include "archer_tensor_tracer.h" - -#include -#include - -#include "utils/archer_logger.h" - -void ArcherTensorTracer::AddTrace(const std::uint32_t layer_id, - const std::vector buffer) -{ - assert(request_id_ != UINT64_MAX); - if (traces_.find(request_id_) == traces_.end()) { - traces_[request_id_] = std::vector>>(); - } - traces_[request_id_].push_back({layer_id, buffer}); - - max_layer_id_ = std::max(max_layer_id_, layer_id); -} - -void ArcherTensorTracer::ClearRequestID() -{ - if (traces_.size() > 1000) { traces_.erase(request_id_); } - request_id_ = UINT64_MAX; -} - -std::vector ArcherTensorTracer::GetCandidates(std::uint32_t layer_id) -{ - if (traces_.size() == 1) { return {}; } - - std::vector< - std::pair>>, double>> - candidates; - - for (auto& req_item : traces_[request_id_]) { - auto req_layer = req_item.first; - auto req_tensors = req_item.second; - - for (auto& hist_item : traces_) { - auto hist_req_id = hist_item.first; - auto value = hist_item.second; - - if (hist_req_id == request_id_) continue; - if (layer_id < value[0].first) continue; - - for (auto& value_item : value) { - auto hist_layer = value_item.first; - auto hist_tensors = value_item.second; - - if (req_layer != hist_layer) continue; - - std::vector overlap; - std::set_intersection(req_tensors.begin(), - req_tensors.end(), - hist_tensors.begin(), - hist_tensors.end(), - std::back_inserter(overlap)); - - std::vector total; - std::set_union(req_tensors.begin(), - req_tensors.end(), - hist_tensors.begin(), - hist_tensors.end(), - std::back_inserter(total)); - - double prob = - static_cast(overlap.size()) / total.size() / (layer_id - req_layer + 1); - - std::vector>> candidate_value; - for (auto& pair : value) { - if (pair.first > layer_id) { candidate_value.push_back(pair); } - } - - if (!candidate_value.empty()) { candidates.push_back({candidate_value, prob}); } - break; - } - } - } - - if (candidates.empty()) { return {}; } - - std::unordered_map> - tensor_probs; // , prob - for (auto& item : candidates) { - auto layer_tensors = item.first; - auto prob = item.second; - - for (auto& layer_item : layer_tensors) { - auto layer = layer_item.first; - auto tensors = layer_item.second; - if (tensor_probs.find(layer) == tensor_probs.end()) { - tensor_probs[layer] = std::unordered_map(); - } - auto& layer_tensor_probs = tensor_probs[layer]; - for (std::uint32_t tensor : tensors) { - if (layer_tensor_probs.find(tensor) == layer_tensor_probs.end()) { - layer_tensor_probs[tensor] = 0; - } - layer_tensor_probs[tensor] += prob; - } - } - } - - // find top 10 tensors for each layer id - std::vector tensor_ids; - for (auto& item : tensor_probs) { - // auto layer = item.first; - auto& layer_tensor_probs = item.second; - - // if (layer < layer_id + 2) continue; // skip the layers that are too close to the current - // layer - - std::vector> layer_tensor_probs_vec( - layer_tensor_probs.begin(), layer_tensor_probs.end()); - - layer_tensor_probs_vec.erase( - std::remove_if( - layer_tensor_probs_vec.begin(), - layer_tensor_probs_vec.end(), - [](const std::pair& pair) { return pair.second < 0.01; }), - layer_tensor_probs_vec.end()); - - if (layer_tensor_probs_vec.empty()) { continue; } - - std::sort(layer_tensor_probs_vec.begin(), - layer_tensor_probs_vec.end(), - [](const std::pair& a, - const std::pair& b) { return a.second > b.second; }); - - std::size_t width = 20; - for (std::uint32_t i = 0; i < std::min(layer_tensor_probs_vec.size(), width); ++i) { - 
tensor_ids.push_back(layer_tensor_probs_vec[i].first); - } - } - return tensor_ids; -} diff --git a/core/trace/archer_tensor_tracer.h b/core/trace/archer_tensor_tracer.h deleted file mode 100644 index c91ba94..0000000 --- a/core/trace/archer_tensor_tracer.h +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) TorchMoE. -// SPDX-License-Identifier: Apache-2.0 - -// TorchMoE Team - -#pragma once - -#include -#include -#include -#include - -class ArcherTensorTracer { -public: - // ArcherTensorTracer(const std::string& prefix); - - void SetRequestID(const std::uint64_t& request_id) { request_id_ = request_id; } - std::uint64_t GetRequestID() { return request_id_; } - void ClearRequestID(); - - void AddTrace(const std::uint32_t layer_id, const std::vector buffer); - std::vector GetCandidates(std::uint32_t layer_id); - -private: - std::uint64_t request_id_; - std::unordered_map>>> - traces_; - std::uint32_t max_layer_id_ = 0; -}; diff --git a/core/trace/model_topology.cpp b/core/trace/model_topology.cpp deleted file mode 100644 index 54d33e2..0000000 --- a/core/trace/model_topology.cpp +++ /dev/null @@ -1,699 +0,0 @@ -// Copyright (c) TorchMoE. -// SPDX-License-Identifier: Apache-2.0 - -// TorchMoE Team - -#include "model_topology.h" - -#include -#include -#include -#include -#include -#include -#include "aio/archer_prio_aio_handle.h" -#include "aio/archer_tensor_handle.h" -#include "aio/archer_tensor_index.h" -#include "common/time.h" -#include "common/types.h" -#include "memory/memory_pool.h" -#include "memory/stream_pool.h" -#include "parallel/expert_dispatcher.h" -#include "prefetch/task_scheduler.h" -#include "utils/archer_logger.h" - -cudaStream_t kCudaStreamH2D = NULL; -std::unique_ptr kTopologyHandle = nullptr; - -const std::string Node::str() noexcept -{ - // write same string using c style sprintf - std::stringstream ss; - for (auto& tensor_id : tensor_ids) { ss << tensor_id << ","; } - - char buffer[1024]; - memset(buffer, 0, 1024); - sprintf(buffer, - "ID[%ld,%lx] (%ldMB) STATE(%d) TENSOR[%s] DEVICE[%s;%s;%s];", - id, - corr_id, - byte_size / MB, - state.load(), - ss.str().c_str(), - device.str().c_str(), - default_device.str().c_str(), - default_host.str().c_str()); - - return std::string(buffer); -} - -Node::Node() - : corr_id(0), - byte_size(0), - last_access_time(MCIROSECONDS_SINCE_EPOCH), - device(DISK_DEVICE), - default_device(DEFAULT_CUDA_DEVICE) -{ -} - -void Node::SetDevice(const torch::Device& target_device, - bool on_demand, - cudaStream_t stream) noexcept -{ - ARCHER_LOG_DEBUG("SetDevice: ", str(), " to ", target_device.str()); - if (device == target_device) { - ARCHER_LOG_DEBUG("SetDevice: " + str() + " to " + target_device.str() + - " but device is the same"); - return; - } - - if (device.type() == target_device.type()) { - ARCHER_LOG_WARN("SetDevice: " + str() + " to " + target_device.str() + - " but device type is the same"); - return; - } - - if (kCudaStreamH2D == NULL) { - auto cudaError = cudaStreamCreateWithFlags(&kCudaStreamH2D, cudaStreamNonBlocking); - if (cudaError != cudaSuccess) { - ARCHER_LOG_ERROR("cudaStreamCreate failed: ", cudaGetErrorString(cudaError)); - exit(-1); - } - } - - if (target_device == DISK_DEVICE) { - SetModuleDisk(tensor_ids); - if (host_memory_ptr != nullptr) { - kHostMemoryPool->FreeMemory(id, host_memory_ptr, byte_size, CPU_DEVICE); - host_memory_ptr = nullptr; - } - if (device_memory_ptr != nullptr) { - kDeviceMemoryPool->FreeMemory(id, device_memory_ptr, byte_size, device); - device_memory_ptr = nullptr; - } - } else { - // both are 
null, which means the node is not initialized - if (host_memory_ptr == nullptr && device_memory_ptr == nullptr) { - // int numa_id = - // default_device.index() / 4; // TODO: 8 gpus, 2 numa nodes, so 4 gpus per numa - host_memory_ptr = kHostMemoryPool->AllocateMemory(id, byte_size, CPU_DEVICE); - assert(host_memory_ptr != nullptr); - - auto start_time = MCIROSECONDS_SINCE_EPOCH; - SetModuleMemoryFromDisk(tensor_ids, host_memory_ptr, on_demand); - auto end_time = MCIROSECONDS_SINCE_EPOCH; - ARCHER_LOG_DEBUG("SetModuleMemoryFromDisk time:", end_time - start_time, " us"); - } - - if (target_device.is_cuda()) { - // ARCHER_LOG_DEBUG("Allocate GPU Memory for node {}", this->id); - device_memory_ptr = kDeviceMemoryPool->AllocateMemory(id, byte_size, target_device); - // ARCHER_LOG_DEBUG("Allocate GPU Memory for node {} done", this->id); - assert(device_memory_ptr != nullptr); - assert(host_memory_ptr != nullptr); - - auto start_time = MCIROSECONDS_SINCE_EPOCH; - if (stream == nullptr) { - cudaMemcpy(device_memory_ptr, host_memory_ptr, byte_size, cudaMemcpyHostToDevice); - } else { - cudaMemcpyAsync( - device_memory_ptr, host_memory_ptr, byte_size, cudaMemcpyHostToDevice, stream); - cudaStreamSynchronize(stream); - } - SetModuleCudaMemoryFromCPU(tensor_ids, device_memory_ptr, target_device); - auto end_time = MCIROSECONDS_SINCE_EPOCH; - ARCHER_LOG_DEBUG("SetModuleCudaMemoryFromCPU time: {} us", end_time - start_time); - } - - if (target_device.is_cpu() && device.is_cuda()) { - assert(host_memory_ptr != nullptr); - auto start_time = MCIROSECONDS_SINCE_EPOCH; - SetModuleMemoryFromCuda(tensor_ids, host_memory_ptr); - kDeviceMemoryPool->FreeMemory(id, device_memory_ptr, byte_size, device); - device_memory_ptr = nullptr; - auto end_time = MCIROSECONDS_SINCE_EPOCH; - ARCHER_LOG_DEBUG("SetModuleMemoryFromCuda time: {} us", end_time - start_time); - } - } - device = target_device; -} - -ArcherTopologyHandle::ArcherTopologyHandle() {} - -NodePtrList ArcherTopologyHandle::GetLFUNodes(const torch::Device& device) -{ - NodePtrList nodes; - std::lock_guard lock(mutex_); - for (auto node_body : lfu_nodes_) { - CONTINUE_IF_NULL(node_body); - if (node_body->node->device == device) { nodes.push_back(node_body->node); } - } - return nodes; -} - -NodePtrList ArcherTopologyHandle::GetDenseNodes() -{ - NodePtrList nodes; - for (auto stage : pipeline_.stages) { - if (stage->is_sparse) { continue; } - for (auto node_body : stage->nodes) { nodes.push_back(node_body->node); } - } - return nodes; -} -NodePtrList ArcherTopologyHandle::GetSparseNodes() -{ - NodePtrList nodes; - for (auto stage : pipeline_.stages) { - if (!stage->is_sparse) { continue; } - for (auto node_body : stage->nodes) { nodes.push_back(node_body->node); } - } - return nodes; -} - -NodePtrList ArcherTopologyHandle::GetDenseNodes(const NodePtr& node, const std::size_t& k) -{ - NodePtrList nodes; - - std::size_t low_corr_id = node->corr_id & 0xFFFFFFFF; // stage id - std::size_t high_corr_id = node->corr_id >> 32; // node id - bool is_last_node = (0xFFFFFFFF == high_corr_id); - if (is_last_node) { - high_corr_id = 0; // reset to 0 avoid miss use - } - - std::lock_guard lock(mutex_); - - low_corr_id++; - std::size_t count = 0; - while ((low_corr_id < pipeline_.stages.size()) && (count < k)) { - // Due to MoE design, we only process layer by layer - auto stage = pipeline_.stages[low_corr_id]; - low_corr_id++; - if (stage->is_sparse) { continue; } - - nodes.push_back(stage->nodes[0]->node); - count++; - } - return nodes; -} - -NodePtrList 
ArcherTopologyHandle::GetSparseNodes(const NodePtr& node, const std::size_t& k) -{ - NodePtrList nodes; - - std::size_t low_corr_id = node->corr_id & 0xFFFFFFFF; // stage id - std::size_t high_corr_id = node->corr_id >> 32; // node id - bool is_last_node = (0xFFFFFFFF == high_corr_id); - if (is_last_node) { - high_corr_id = 0; // reset to 0 avoid miss use - } - - std::lock_guard lock(mutex_); - - low_corr_id++; - std::size_t count = 0; - while ((low_corr_id < pipeline_.stages.size()) && (count < k)) { - // Due to MoE design, we only process layer by layer - auto stage = pipeline_.stages[low_corr_id]; - - low_corr_id++; - if (!stage->is_sparse) { continue; } - - nodes.push_back(stage->nodes[0]->node); - count++; - } - return nodes; -} - -std::uint64_t ArcherTopologyHandle::GetLastActivateStage(const HashID& hash_id) -{ - std::lock_guard lock(mutex_); - auto it = last_active_stage_.find(hash_id); - if (it == last_active_stage_.end()) { return 0; } - return it->second; -} - -std::vector> ArcherTopologyHandle::GetNodeVisitCounts() -{ - std::lock_guard lock(mutex_); - std::vector> node_visit_counts; - for (auto& stage : pipeline_.stages) { - for (auto& node : stage->nodes) { - node->node->io_state = NODE_STATE_NONE; - std::vector metrics{node->visit_cnt, - node->gpu_visit_cnt, - node->cpu_visit_cnt, - node->hit_cnt, - node->gpu_hit_cnt, - node->cpu_hit_cnt, - node->node->tensor_ids.size(), - node->prefetch_cnt, - node->node->unused_count, - node->node->io_state, - node->is_sparse}; - node_visit_counts.push_back(metrics); - } - } - return node_visit_counts; -} - -std::vector ArcherTopologyHandle::GetChildVisitCounts() -{ - std::lock_guard lock(mutex_); - int num_layers = 0; - int num_experts = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - num_layers += 1; - num_experts = stage->nodes.size(); - } - } - std::vector child_visit_counts((num_layers - 1) * num_experts * num_experts); - int layer_idx = 0; - int parent_idx = 0; - int expert_idx = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - for (auto& node : stage->nodes) { - if (node->children.size() > 0) { - for (auto& count : node->children_visit_cnt) { - child_visit_counts[layer_idx * num_experts * num_experts + - parent_idx * num_experts + expert_idx] = count; - expert_idx++; - } - } - parent_idx++; - expert_idx = 0; - } - layer_idx++; - parent_idx = 0; - } - } - - return child_visit_counts; -} - -void ArcherTopologyHandle::SetNodeVisitCounts(const std::vector& visit_counts) -{ - std::lock_guard lock(mutex_); - std::size_t num_nodes = 0; - std::size_t num_experts = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - num_nodes += stage->nodes.size(); - num_experts = stage->nodes.size(); - } - } - if (visit_counts.size() != num_nodes) { - ARCHER_LOG_ERROR( - "visit_counts size {} not equal to num_nodes {}", visit_counts.size(), num_nodes); - return; - } - - int layer_idx = 0; - int expert_idx = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - for (auto& node : stage->nodes) { - node->visit_cnt = visit_counts[layer_idx * num_experts + expert_idx]; - expert_idx++; - } - layer_idx++; - expert_idx = 0; - } - } - - DisableTrace(); -} -void ArcherTopologyHandle::SetChildVisitCounts(const std::vector& visit_counts) -{ - std::lock_guard lock(mutex_); - std::size_t num_layers = 0; - std::size_t num_experts = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - num_layers += 1; - num_experts = stage->nodes.size(); - } - } - if (visit_counts.size() 
!= (num_layers - 1) * num_experts * num_experts) { - ARCHER_LOG_ERROR( - "visit_counts size {} not equal to num_layers {}", visit_counts.size(), num_layers); - return; - } - - int layer_idx = 0; - int parent_idx = 0; - int expert_idx = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - for (auto& node : stage->nodes) { - if (node->children.size() > 0) { - for (auto& count : node->children_visit_cnt) { - count = visit_counts[layer_idx * num_experts * num_experts + - parent_idx * num_experts + expert_idx]; - expert_idx++; - } - } - parent_idx++; - expert_idx = 0; - } - layer_idx++; - parent_idx = 0; - } - } - - DisableTrace(); -} - -bool ArcherTopologyHandle::IsLastNode(const NodePtr& node) -{ - std::lock_guard lock(mutex_); - auto last_stage_ptr = pipeline_.stages.back(); - auto& nodes = last_stage_ptr->nodes; - for (auto& n : nodes) { - if (n->node == node) { return true; } - } - return false; -} -bool ArcherTopologyHandle::IsFirstNode(const NodePtr& node) -{ - std::lock_guard lock(mutex_); - auto first_stage_ptr = pipeline_.stages.front(); - auto& nodes = first_stage_ptr->nodes; - for (auto& n : nodes) { - if (n->node == node) { return true; } - } - return false; -} - -void ArcherTopologyHandle::InitializeTopology( - const std::vector>>>& topology) -{ - std::lock_guard lock(mutex_); - pipeline_.stages.clear(); - std::size_t node_id = 0; - std::size_t layer_id = 0; - std::size_t last_sparse_layer_id = UINT64_MAX; - - size_t num_sparse_layers = 0; - size_t num_experts = 0; - - std::vector all_nodes; - - for (auto& stage : topology) { - auto& stage_tensors = std::get<1>(stage); - auto stage_ptr = std::make_shared(stage_tensors.size() > 1); - - std::size_t expert_id = 0; - for (auto& tensor_ids : stage_tensors) { - auto node_ptr = std::make_shared(); - node_ptr->tensor_ids = tensor_ids; - int64_t byte_size = 0; - for (auto& tensor_id : tensor_ids) { - auto it = kTensorIndex->find(tensor_id); - if (it != kTensorIndex->end()) { - std::int64_t size_aligned = - (it->second.size + kAioAlignment - 1) & ~(kAioAlignment - 1); - byte_size += size_aligned; - } else { - ARCHER_LOG_ERROR("Tensor {} not found in tensor index", tensor_id); - } - } - node_ptr->byte_size = byte_size; - node_ptr->id = node_id; - node_ptr->corr_id = (layer_id & 0xFFFFFFFF) | ((expert_id & 0xFFFFFFFF) << 32); - node_ptr->is_sparse = stage_ptr->is_sparse; - - all_nodes.push_back(node_ptr); - - auto node_body_ptr = std::make_shared(node_ptr); - node_body_ptr->is_sparse = stage_ptr->is_sparse; - - stage_ptr->nodes.push_back(node_body_ptr); - - node_id++; - expert_id++; - } - pipeline_.stages.push_back(stage_ptr); - auto current_layer_id = layer_id; - layer_id++; - - if (stage_ptr->is_sparse) { - if (UINT64_MAX == last_sparse_layer_id) { - last_sparse_layer_id = current_layer_id; - continue; - } - // set node_body_ptr vectors to be the same size as the number of experts - // all counts initialized to 0 - auto last_sparse_stage_ptr = pipeline_.stages[last_sparse_layer_id]; - for (auto& node : last_sparse_stage_ptr->nodes) { - node->children_visit_cnt.resize(stage_ptr->nodes.size(), 0); - node->children = stage_ptr->nodes; - } - last_sparse_layer_id = current_layer_id; - - num_sparse_layers++; - num_experts = stage_ptr->nodes.size(); - } - } - - // set last stage nodes corr_id higher 32 bits to be 0xFFFFFFFF - auto last_stage_ptr = pipeline_.stages.back(); - for (auto& node_body : last_stage_ptr->nodes) { - node_body->node->corr_id = (node_body->node->corr_id & 0xFFFFFFFF) | (UINT64_MAX << 32); - } - - // output 
every tensor id in node - for (auto& stage : pipeline_.stages) { - for (auto& node : stage->nodes) { - std::stringstream ss; - for (auto& tensor_id : node->node->tensor_ids) { ss << tensor_id << " "; } - // ARCHER_LOG_DEBUG("Node {} tensor ids {}", node->node->id, ss.str()); - lfu_nodes_.push_back(node); - } - } - - ARCHER_LOG_DEBUG("InitializeTopology pipeline_.stages.size() {}", pipeline_.stages.size()); - - // Model placement - auto num_gpu = GetDeviceCount(); - std::vector free_device_mem(num_gpu, 0); - for (int i = 0; i < num_gpu; i++) { - free_device_mem[i] = kDeviceMemoryPool->GetMemoryCapacity(torch::Device(torch::kCUDA, i)); - } - - auto sparse_nodes = GetSparseNodes(); - auto dense_nodes = GetDenseNodes(); - - ARCHER_LOG_DEBUG("InitializeTopology num_gpu {} sparse_nodes.size() {} dense_nodes.size() {}", - num_gpu, - sparse_nodes.size(), - dense_nodes.size()); - - int target_device_id = 0; - int dense_gpu_idx = 0; - int sparse_gpu_idx = 0; - - // Split evently dense nodes only - int num_dense_nodes_per_device = std::ceil(dense_nodes.size() / num_gpu / 2); - // int total_dense_nodes = dense_nodes.size(); - int counter = 0; - for (auto& node_ptr : dense_nodes) { - // split dense node evenly among GPUs - node_ptr->default_device = torch::Device(torch::kCUDA, target_device_id); - counter++; - if (counter % num_dense_nodes_per_device == 0) { - target_device_id = (target_device_id + 1) % num_gpu; - } - } - dense_nodes.back()->default_device = torch::Device(torch::kCUDA, num_gpu - 1); - - // split evenly sparse nodes among GPUs - for (auto& node_ptr : sparse_nodes) { - node_ptr->default_device = torch::Device(torch::kCUDA, target_device_id); - target_device_id = (target_device_id + 1) % num_gpu; - } - - ARCHER_LOG_DEBUG("InitializeTopology pipeline_.stages.size() {}", pipeline_.stages.size()); - - for (auto& node_ptr : all_nodes) { - ARCHER_LOG_DEBUG("Node {} {} device {}", - node_ptr->id, - node_ptr->is_sparse, - node_ptr->default_device.str()); - } - - EnableTrace(); -} - -NodePtr ArcherTopologyHandle::GetNodeFromTensorID(const TensorID& tensor_id) -{ - std::lock_guard lock(mutex_); - - auto it = tensor_id_to_node_.find(tensor_id); - if (it != tensor_id_to_node_.end()) { - return it->second; - } else { - // search in pipeline - for (auto& stage : pipeline_.stages) { - for (auto& node_body : stage->nodes) { - for (auto& id : node_body->node->tensor_ids) { - if (id == tensor_id) { - tensor_id_to_node_[tensor_id] = node_body->node; - return node_body->node; - } - } - } - } - } - ARCHER_LOG_ERROR("Tensor {} not found in tensor id to node map", tensor_id); - return nullptr; -} - -NodeBodyPtr ArcherTopologyHandle::GetNodeBodyFromCorrID(const std::uint64_t& correlation_id) -{ - std::lock_guard lock(mutex_); - - std::uint64_t high_corr_id = correlation_id >> 32; // For children in the same level - std::uint64_t low_corr_id = correlation_id & 0xFFFFFFFF; // For model inference pipeline - - bool is_last_node = (0xFFFFFFFF == high_corr_id); - if (is_last_node) { - high_corr_id = 0; // reset to 0 avoid miss use - } - - auto stage = pipeline_.stages[low_corr_id]; - auto node_body = stage->nodes[high_corr_id]; - - return node_body; -} - -std::int64_t ArcherTopologyHandle::GetSparseCacheLimit(const torch::Device& device) -{ - std::int64_t dense_cache_size = 0; - for (auto& stage : pipeline_.stages) { - for (auto& node_body : stage->nodes) { - if (stage->is_sparse) continue; - if (node_body->node->device == device) { - dense_cache_size += node_body->node->byte_size; - } - } - } - - std::int64_t 
device_size_limit = (device.is_cuda()) - ? kDeviceMemoryPool->GetMemoryCapacity(device) - : kHostMemoryPool->GetMemoryCapacity(); - assert(device_size_limit > dense_cache_size); - std::int64_t sparse_cache_size = device_size_limit - dense_cache_size; - - return sparse_cache_size; -} - -std::tuple ArcherTopologyHandle::GetNumLayersAndExperts() -{ - std::lock_guard lock(mutex_); - int num_layers = 0; - int num_experts = 0; - for (auto& stage : pipeline_.stages) { - if (stage->is_sparse) { - num_layers += 1; - num_experts = stage->nodes.size(); - } - } - return std::make_tuple(num_layers, num_experts); -} - -// CPU, GPU -> DISK -// Moves tensors from CPU/GPU to disk. -void SetModuleDisk(std::vector& tensor_ids) -{ - // ARCHER_LOG_DEBUG("SetModuleDisk {} tensors", tensor_ids.size()); - for (const auto& tensor_id : tensor_ids) { - // void* old_ptr = kTensorIndex->find(tensor_id)->second.tensor.data_ptr(); - auto it = kTensorIndex->find(tensor_id); - - at::TensorOptions options; - options = options.device(torch::kCPU); - options = options.dtype(it->second.tensor.dtype()); - auto tensor = torch::zeros({1}, options); - it->second.tensor.set_data(tensor); - } -} - -std::mutex kReadMutex; - -// DISK -> CPU -void SetModuleMemoryFromDisk(std::vector& tensor_ids, void* host_ptr, bool on_demand) -{ - std::int64_t param_size = 0; - for (const auto& tensor_id : tensor_ids) { - // void* old_ptr = kTensorIndex->find(tensor_id)->second.tensor.data_ptr(); - kArcherTensorHandle->ReadTensor( - tensor_id, (void*)((char*)host_ptr + param_size), on_demand); - auto it = kTensorIndex->find(tensor_id); - auto options = torch::TensorOptions() - .dtype(it->second.options.dtype()) - .layout(it->second.options.layout()) - .device(torch::kCPU) - .requires_grad(it->second.options.requires_grad()) - .pinned_memory(it->second.options.pinned_memory()); - - ARCHER_LOG_DEBUG("SetModuleMemoryFromDisk tensor {}", it->second.DebugString()); - auto tensor_tmp = torch::from_blob((void*)((char*)host_ptr + param_size), - it->second.shape, - DoNothingDeleter{}, - options); - if (!it->second.tensor.defined()) { it->second.tensor = torch::zeros({1}, options); } - it->second.tensor.set_data(tensor_tmp); - std::int64_t size_aligned = (it->second.size + kAioAlignment - 1) & ~(kAioAlignment - 1); - param_size += size_aligned; - } -} - -// CPU -> GPU -void SetModuleCudaMemoryFromCPU(std::vector& tensor_ids, - void* device_ptr, - const torch::Device& device) -{ - // ARCHER_LOG_DEBUG("SetModuleCudaMemoryFromCPU {} tensors", tensor_ids.size()); - std::int64_t param_size = 0; - for (const auto& tensor_id : tensor_ids) { - auto it = kTensorIndex->find(tensor_id); - ARCHER_LOG_DEBUG( - "SetModuleCudaMemoryFromCPU tensor {} -> {}", it->second.DebugString(), device.str()); - auto tensor_options = torch::TensorOptions() - .dtype(it->second.options.dtype()) - .layout(it->second.options.layout()) - .device(device) - .requires_grad(it->second.options.requires_grad()) - .pinned_memory(false); - it->second.tensor.set_data(torch::from_blob((char*)device_ptr + param_size, - it->second.shape, - DoNothingDeleter{}, - tensor_options)); - std::int64_t size_aligned = (it->second.size + kAioAlignment - 1) & ~(kAioAlignment - 1); - param_size += size_aligned; - } - // ARCHER_LOG_DEBUG("SetModuleCudaMemoryFromCPU {} tensors done", tensor_ids.size()); -} - -// GPU -> CPU -void SetModuleMemoryFromCuda(std::vector& tensor_ids, void* host_ptr) -{ - std::int64_t param_size = 0; - for (const auto& tensor_id : tensor_ids) { - // void* old_ptr = 
kTensorIndex->find(tensor_id)->second.tensor.data_ptr(); - - auto it = kTensorIndex->find(tensor_id); - ARCHER_LOG_DEBUG("SetModuleMemoryFromCuda tensor ", it->second.DebugString()); - it->second.tensor.set_data(torch::from_blob((char*)host_ptr + param_size, - it->second.shape, - DoNothingDeleter{}, - it->second.options)); - // kArcherTensorHandle->UpdateTensorMap(old_ptr, it->second.tensor.data_ptr()); - std::int64_t size_aligned = (it->second.size + kAioAlignment - 1) & ~(kAioAlignment - 1); - param_size += size_aligned; - } - // ARCHER_LOG_DEBUG("SetModuleMemoryFromCuda {} tensors done", tensor_ids.size()); -} diff --git a/core/trace/model_topology.h b/core/trace/model_topology.h deleted file mode 100644 index 2cd417c..0000000 --- a/core/trace/model_topology.h +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) TorchMoE. -// SPDX-License-Identifier: Apache-2.0 - -// TorchMoE Team - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include "common/pytorch.h" -#include "common/types.h" -#include "memory/memory_pool.h" -#include "utils/noncopyable.h" - -enum NodeState { - NODE_STATE_NONE = 0x0, - NODE_STATE_CACHED = 0x1, - NODE_STATE_PREFETCHED = 0x2, - NODE_STATE_VISITED = 0x4, -}; - -extern cudaStream_t kCudaStreamH2D; - -struct Node { - std::vector tensor_ids; - std::int64_t byte_size; - std::size_t last_access_time; - std::size_t last_prefetch_time = 0; - - std::size_t id; - std::size_t corr_id; - - torch::Device device = DISK_DEVICE; - torch::Device default_device = DEFAULT_CUDA_DEVICE; // FIXME: should be set by scheduler - torch::Device default_host = CPU_DEVICE; - torch::Device initial_host = DISK_DEVICE; - - std::atomic_uint8_t state{0}; // 0 for ready, 1 for moving - - std::mutex mutex; - std::condition_variable cv; - - std::uint64_t visit_count = 0; - std::uint64_t unused_count = 0; - bool is_sparse = false; - NodeState io_state = NODE_STATE_NONE; - - bool is_overflow = false; - - void* host_memory_ptr = nullptr; - void* device_memory_ptr = nullptr; - -public: - explicit Node(); - const std::string str() noexcept; - void SetDevice(const torch::Device& target_device, - bool ondemand = false, - cudaStream_t stream = nullptr) noexcept; -}; - -typedef std::shared_ptr NodePtr; -typedef std::vector NodePtrList; -typedef std::tuple FilterResult; - -struct NodeBody; -typedef std::shared_ptr NodeBodyPtr; - -struct NodeBody { - NodePtr node; - std::vector children; - std::vector children_visit_cnt; - std::unordered_set activate_request; - std::size_t prefetch_cnt = 0; - std::size_t visit_cnt = 0; - std::size_t cpu_visit_cnt = 0; - std::size_t gpu_visit_cnt = 0; - std::size_t hit_cnt = 0; - std::size_t gpu_hit_cnt = 0; - std::size_t cpu_hit_cnt = 0; - std::size_t gpu_miss_cnt = 0; - std::size_t cpu_miss_cnt = 0; - bool is_sparse; - std::deque visit_time; - explicit NodeBody(NodePtr node) : node(node), visit_cnt(0) {} - - std::string str() const noexcept - { - std::stringstream ss; - ss << "NodeBody: " << node->str() << " visit_cnt " << visit_cnt << ", child visit ["; - for (auto& visit : children_visit_cnt) { ss << visit << ","; } - ss << "]"; - return ss.str(); - } -}; - -struct Stage { - bool is_sparse; - std::vector nodes; - std::size_t visit_cnt; - std::int64_t byte_size; - std::deque visit_time; - std::unordered_set activate_request; - Stage() : is_sparse(false), visit_cnt(0), byte_size(0) {} - Stage(bool is_sparse) : is_sparse(is_sparse), visit_cnt(0), byte_size(0) {} - - std::string str() const noexcept - { - char buffer[1024]; - memset(buffer, 0, 1024); 
- sprintf(buffer, "Stage[%ld,%ld,%d]", nodes.size(), visit_cnt, is_sparse); - return std::string(buffer); - } -}; -typedef std::shared_ptr StagePtr; - -struct Pipeline { - std::vector stages; - std::size_t visit_cnt = 0; - - std::string str() const noexcept - { - std::stringstream ss; - ss << "Pipeline: " << stages.size() << " stages; visit_cnt " << visit_cnt << std::endl; - return ss.str(); - } -}; -typedef std::shared_ptr PipelinePtr; - -class ArcherTopologyHandle : public noncopyable { -public: - DELETE_COPY_AND_ASSIGN(ArcherTopologyHandle); - - ArcherTopologyHandle(); - ~ArcherTopologyHandle() = default; - - bool IsLastNode(const NodePtr& node); - bool IsFirstNode(const NodePtr& node); - - NodePtrList GetLFUNodes(const torch::Device& device); - - NodePtrList GetDenseNodes(const NodePtr& node, const std::size_t& k); - NodePtrList GetSparseNodes(const NodePtr& node, const std::size_t& k); - NodePtrList GetDenseNodes(); - NodePtrList GetSparseNodes(); - - std::uint64_t GetLastActivateStage(const HashID& hash_id); - - void InitializeTopology( - const std::vector>>>& topology); - - void EnableTrace() noexcept { trace_enabled_ = true; } - void DisableTrace() noexcept { trace_enabled_ = false; } - - std::vector> GetNodeVisitCounts(); - std::vector GetChildVisitCounts(); - void SetNodeVisitCounts(const std::vector& visit_counts); - void SetChildVisitCounts(const std::vector& visit_counts); - - NodePtr GetNodeFromTensorID(const TensorID& tensor_id); - NodeBodyPtr GetNodeBodyFromCorrID(const std::uint64_t& correlation_id); - - std::tuple GetNumLayersAndExperts(); - - std::int64_t GetSparseCacheLimit(const torch::Device& device); - - std::size_t GetNumberOfStages() const noexcept { return pipeline_.stages.size(); } - -private: - Pipeline pipeline_; - std::unordered_set visited_; - std::unordered_map last_active_stage_; - std::vector lfu_nodes_; - std::unordered_map request_time_; - std::unordered_map request_trace_; - std::int64_t visit_count_ = 0; - std::mutex mutex_; - bool trace_enabled_ = true; - - std::unordered_map tensor_id_to_node_; -}; - -extern std::unique_ptr kTopologyHandle; - -#define CONTINUE_IF_NULL(node) \ - if (node == nullptr) continue; -#define BREAK_IF_NULL(node) \ - if (node == nullptr) break; - -extern std::mutex kReadMutex; - -void SetModuleDisk(std::vector& tensor_ids); -void SetModuleMemoryFromDisk(std::vector& tensor_ids, - void* host_ptr, - bool on_demand = false); -void SetModuleCudaMemoryFromCPU(std::vector& tensor_ids, - void* device_ptr, - const torch::Device& device); -void SetModuleMemoryFromCuda(std::vector& tensor_ids, void* host_ptr); diff --git a/core/utils/cuda_utils.cpp b/core/utils/cuda_utils.cpp index b2b74c8..0daa473 100644 --- a/core/utils/cuda_utils.cpp +++ b/core/utils/cuda_utils.cpp @@ -39,3 +39,18 @@ std::size_t GetFreeDeviceMemory(int device_id) cudaMemGetInfo(&free_memory, &total_memory); return free_memory; } + +int CudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind) +{ + return cudaMemcpy(dst, src, count, kind); +} + +int CudaMemcpyAsync(void* dst, + const void* src, + size_t count, + cudaMemcpyKind kind, + cudaStream_t stream) +{ + return cudaMemcpyAsync( + dst, src, count, kind, stream); +} diff --git a/core/utils/cuda_utils.h b/core/utils/cuda_utils.h index 8a7b6fc..0bb1cc5 100644 --- a/core/utils/cuda_utils.h +++ b/core/utils/cuda_utils.h @@ -15,3 +15,10 @@ std::size_t GetFreeDeviceMemory(int device_id); #define DEVICE_CACHE_LIMIT(gid) GetTotalDeviceMemory(gid) * 0.7 #define NUM_DEVICES GetDeviceCount() + +int 
CudaMemcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind); +int CudaMemcpyAsync(void* dst, + const void* src, + size_t count, + cudaMemcpyKind kind, + cudaStream_t stream = 0); diff --git a/examples/interface_example.py b/examples/interface_example.py index af1cfa7..a4c7b6b 100644 --- a/examples/interface_example.py +++ b/examples/interface_example.py @@ -9,7 +9,7 @@ import argparse import datasets import multiprocessing as mp -from transformers import AutoTokenizer, TextStreamer +from transformers import AutoTokenizer, TextStreamer, LlamaTokenizerFast from moe_infinity import MoE parser = argparse.ArgumentParser() @@ -20,7 +20,10 @@ model_name = args.model_name_or_path.split("/")[-1] -tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) +if "grok" in model_name: + tokenizer = LlamaTokenizerFast.from_pretrained("Xenova/grok-1-tokenizer", trust_remote_code=True) +else: + tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, trust_remote_code=True) streamer = TextStreamer(tokenizer) dataset_name = "tasksource/bigbench" @@ -51,6 +54,8 @@ custom_kwargs = {"forced_bos_token_id": 256057} # translate to French elif "mixtral" in args.model_name_or_path.lower(): custom_kwargs = {"pad_token_id": tokenizer.eos_token_id} +elif "grok" in args.model_name_or_path.lower(): + custom_kwargs = {} else: raise ValueError(f"Model {args.model_name_or_path} not supported") diff --git a/moe_infinity/distributed/expert_executor.py b/moe_infinity/distributed/expert_executor.py index 7fe4da9..b2ba2b1 100644 --- a/moe_infinity/distributed/expert_executor.py +++ b/moe_infinity/distributed/expert_executor.py @@ -29,6 +29,25 @@ def set_expert_dispatcher(self, expert_dispatcher): def set_device_map_manager(self, device_map_manager): self.device_map_manager = device_map_manager + def dispatch_local(self, hidden_states, router_mask, layer_id): + num_expert = router_mask.shape[-1] + expert_count = torch.sum(router_mask.view((-1, num_expert)), dim=0).cpu().numpy().flatten() + + expert_list = np.arange(num_expert).astype(int)[expert_count > 0].tolist() + expected_wait_cnt = len(expert_list) + + self.expert_dispatcher.set_inputs(hidden_states, router_mask) + self.expert_dispatcher.set_expected_queue(expected_wait_cnt) + + total_gpus = torch.cuda.device_count() + for expert_id in expert_list: + gpu_id = expert_id % total_gpus + self.expert_dispatcher.enqueue_expert(layer_id, expert_id, gpu_id, False) + + result = self.expert_dispatcher.wait_expert() + + return result + def dispatch(self, hidden_states, router_mask, layer_id): num_expert = router_mask.shape[-1] expert_count = torch.sum(router_mask.view((-1, num_expert)), dim=0).cpu().numpy().flatten() diff --git a/moe_infinity/entrypoints/big_modeling.py b/moe_infinity/entrypoints/big_modeling.py index 2359980..b570a5e 100644 --- a/moe_infinity/entrypoints/big_modeling.py +++ b/moe_infinity/entrypoints/big_modeling.py @@ -64,7 +64,7 @@ def __init__( f"Please provide a configuration file or create a default one at {default_config_path}." 
) config = default_config_path - model_config = AutoConfig.from_pretrained(model_name_or_path) + model_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) architecture = model_config.architectures[0].lower() arch = None @@ -127,6 +127,7 @@ def __init__( "flash_attention_2" if is_flash_attn_available else "eager" ), is_flash_attn_available=is_flash_attn_available, + trust_remote_code=True, ) def _configure_hook(self, input_ids: torch.LongTensor): diff --git a/moe_infinity/models/grok.py b/moe_infinity/models/grok.py index b5e5f96..d276bee 100644 --- a/moe_infinity/models/grok.py +++ b/moe_infinity/models/grok.py @@ -41,6 +41,15 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: # we cast back to the input dtype routing_weights = routing_weights.to(hidden_states.dtype) + router_mask = F.one_hot(selected_experts, num_classes=self.num_experts) + routing_weights_mask = (routing_weights[:, :, None] * router_mask).permute( + 0, 2, 1 + ) + router_mask = router_mask.permute(0, 2, 1) + # assume top-2 here + router_mask = torch.logical_or(router_mask[:, :, 0], router_mask[:, :, 1]) + routing_weights_mask = torch.sum(routing_weights_mask, dim=-1) + expert_index = selected_experts.reshape(batch_size, sequence_length, self.top_k) for i in range(batch_size): seq_id = self.seq_id_list[i] @@ -52,39 +61,46 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]: dtype=hidden_states.dtype, device=hidden_states.device, ) - # One hot encode the selected experts to create an expert mask - # this will be used to easily index which expert is going to be sollicitated - expert_mask = torch.nn.functional.one_hot( - selected_experts, num_classes=self.num_experts - ).permute(2, 1, 0) - - # Loop over all available experts in the model and perform the computation on each expert - for expert_idx in range(self.num_experts): - expert_layer = self.experts[expert_idx] - idx, top_x = torch.where(expert_mask[expert_idx]) - - if top_x.shape[0] == 0: - continue - - # in torch it is faster to index using lists than torch tensors - top_x_list = top_x.tolist() - idx_list = idx.tolist() - - # Index the correct hidden states and compute the expert hidden state for - # the current expert. We need to make sure to multiply the output hidden - # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) - # print(f"current hidden_states device: {current_state.device} expert_layer device: {expert_layer}") - current_hidden_states = ( - expert_layer(current_state).to(routing_weights.device) - * routing_weights[top_x_list, idx_list, None] - ) - - # However `index_add_` only support torch tensors for indexing so we'll use - # the `top_x` tensor here. 
- final_hidden_states.index_add_( - 0, top_x, current_hidden_states.to(hidden_states.dtype) - ) + + results = self.expert_executor.dispatch_local(hidden_states, router_mask, self.layer_id) + for output, _, idx, _ in results: + token_indices = router_mask[:, idx].bool() + final_hidden_states[token_indices, :] += output.to(routing_weights_mask.device) * routing_weights_mask[token_indices, idx][:, None] + + # # One hot encode the selected experts to create an expert mask + # # this will be used to easily index which expert is going to be sollicitated + # expert_mask = torch.nn.functional.one_hot( + # selected_experts, num_classes=self.num_experts + # ).permute(2, 1, 0) + + # # Loop over all available experts in the model and perform the computation on each expert + # for expert_idx in range(self.num_experts): + # expert_layer = self.experts[expert_idx] + # idx, top_x = torch.where(expert_mask[expert_idx]) + + # if top_x.shape[0] == 0: + # continue + + # # in torch it is faster to index using lists than torch tensors + # top_x_list = top_x.tolist() + # idx_list = idx.tolist() + + # # Index the correct hidden states and compute the expert hidden state for + # # the current expert. We need to make sure to multiply the output hidden + # # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + # current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + # # print(f"current hidden_states device: {current_state.device} expert_layer device: {expert_layer}") + # current_hidden_states = ( + # expert_layer(current_state).to(routing_weights.device) + # * routing_weights[top_x_list, idx_list, None] + # ) + + # # However `index_add_` only support torch tensors for indexing so we'll use + # # the `top_x` tensor here. + # final_hidden_states.index_add_( + # 0, top_x, current_hidden_states.to(hidden_states.dtype) + # ) + final_hidden_states = final_hidden_states.reshape( batch_size, sequence_length, hidden_dim ) diff --git a/moe_infinity/models/mixtral.py b/moe_infinity/models/mixtral.py index 0f92a51..65cba18 100644 --- a/moe_infinity/models/mixtral.py +++ b/moe_infinity/models/mixtral.py @@ -86,17 +86,22 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device ) - for expert_idx in range(self.num_experts): - # expert_layer = self.experts[expert_idx] - token_indices = router_mask[:, expert_idx] - current_state = hidden_states[token_indices, :] - - if token_indices.any(): - current_hidden_states = ( - self.experts[expert_idx](current_state).to(routing_weights_mask.device) - * routing_weights_mask[token_indices, expert_idx][:, None] - ) - final_hidden_states[token_indices, :] += current_hidden_states + results = self.expert_executor.dispatch_local(hidden_states, router_mask, self.layer_id) + for output, _, idx, _ in results: + token_indices = router_mask[:, idx].bool() + final_hidden_states[token_indices, :] += output.to(routing_weights_mask.device) * routing_weights_mask[token_indices, idx][:, None] + + # for expert_idx in range(self.num_experts): + # # expert_layer = self.experts[expert_idx] + # token_indices = router_mask[:, expert_idx] + # current_state = hidden_states[token_indices, :] + + # if token_indices.any(): + # current_hidden_states = ( + # self.experts[expert_idx](current_state).to(routing_weights_mask.device) + # * routing_weights_mask[token_indices, expert_idx][:, None] + # ) + # final_hidden_states[token_indices, :] += current_hidden_states 
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) diff --git a/moe_infinity/models/nllb_moe.py b/moe_infinity/models/nllb_moe.py index 16d4902..894fdb1 100644 --- a/moe_infinity/models/nllb_moe.py +++ b/moe_infinity/models/nllb_moe.py @@ -68,18 +68,20 @@ def forward(self, expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id) self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix) - # self.expert_prefetcher.prefetch_tensors(self.layer_id, router_mask, - # self.expert_tensor_ids, - # n_tokens) - - for expert_id, expert in self.experts.items(): - idx = int(expert_id.split("_")[-1]) - token_indices = router_mask[:, :, idx].bool() + results = self.expert_executor.dispatch_local(hidden_states, router_mask, self.layer_id) + for output, _, idx, _ in results: + token_indices = router_mask[:, idx].bool() weights = combining_weights[..., idx] + next_states[token_indices] += torch.einsum("b,be->be", weights[token_indices], output.to(weights.device)) - if token_indices.any(): - expert_output = expert(hidden_states[token_indices]).to(weights.device) - next_states[token_indices] += torch.einsum("b,be->be", weights[token_indices], expert_output) + # for expert_id, expert in self.experts.items(): + # idx = int(expert_id.split("_")[-1]) + # token_indices = router_mask[:, :, idx].bool() + # weights = combining_weights[..., idx] + + # if token_indices.any(): + # expert_output = expert(hidden_states[token_indices]).to(weights.device) + # next_states[token_indices] += torch.einsum("b,be->be", weights[token_indices], expert_output) next_states[next_states == 0] = hidden_states[next_states == 0] hidden_states = next_states diff --git a/moe_infinity/models/switch_transformers.py b/moe_infinity/models/switch_transformers.py index 24339f2..38bea47 100644 --- a/moe_infinity/models/switch_transformers.py +++ b/moe_infinity/models/switch_transformers.py @@ -87,15 +87,18 @@ def forward(self, hidden_states): expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id) self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix) - # self.expert_prefetcher.prefetch_tensors(self.layer_id, router_mask, - # self.expert_tensor_ids, - # n_tokens) - - for expert_id, expert in self.experts.items(): - idx = int(expert_id.split("_")[-1]) + + results = self.expert_executor.dispatch_local(hidden_states, router_mask, self.layer_id) + + for output, _, idx, _ in results: token_indices = router_mask[:, :, idx].bool() - if token_indices.any(): - next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.device) + next_states[token_indices] = output.to(next_states.device) + + # for expert_id, expert in self.experts.items(): + # idx = int(expert_id.split("_")[-1]) + # token_indices = router_mask[:, :, idx].bool() + # if token_indices.any(): + # next_states[token_indices] = expert(hidden_states[token_indices]).to(next_states.device) hidden_states = router_probs * next_states return hidden_states, (router_logits.to("cuda:0", non_blocking=True), diff --git a/moe_infinity/runtime/model_offload.py b/moe_infinity/runtime/model_offload.py index 0cd1648..0a9ad4a 100644 --- a/moe_infinity/runtime/model_offload.py +++ b/moe_infinity/runtime/model_offload.py @@ -26,6 +26,7 @@ ) from moe_infinity.utils import ArcherConfig from moe_infinity.utils.arguments import copy_args_to_device, copy_kwargs_to_device +from moe_infinity.distributed import DistributedExpertExecutor from moe_infinity.memory import 
ExpertPrefetcher import moe_infinity @@ -154,7 +155,7 @@ def init( # ): # os.remove(_archer_config.perfect_cache_file) - # self.expert_executor = DistributedExpertExecutor(archer_config=_archer_config) + self.expert_executor = DistributedExpertExecutor(archer_config=_archer_config) # self.expert_prefetcher = ExpertPrefetcher(self.config) # self.device_map_manager = DeviceMapManager(archer_config=_archer_config) @@ -530,9 +531,7 @@ def archer_from_pretrained(cls, *args, **kwargs): # # make unique and sort # layer_idx = sorted(list(set(layer_idx))) - # self.expert_executor.set_expert_dispatcher( - # self.expert_dispatcher - # ) + self.expert_executor.set_expert_dispatcher(self.expert_dispatcher) module_idx = 0 self.expert_layer_modules = [] @@ -551,7 +550,7 @@ def archer_from_pretrained(cls, *args, **kwargs): module.archer_config = self.archer_config # module.expert_dispatcher = self.expert_dispatcher self.expert_modules.append(module) - # module.expert_executor = self.expert_executor + module.expert_executor = self.expert_executor module.expert_prefetcher = self.expert_prefetcher module.expert_tracer = self.expert_tracer module.expert_predictor = self.expert_predictor diff --git a/requirements.txt b/requirements.txt index a2b7fb1..12cf8bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,9 +3,9 @@ ninja packaging>=20.0 py-cpuinfo torch>=2.1.1 -transformers>=4.37.1, <5.0.0 +transformers>=4.37.1, <4.40 sentencepiece -pydantic==1.10.12 +pydantic==1.10.12 datasets>=2.12.0 pyarrow==12.0.0 accelerate
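
Reviewer note: a minimal, self-contained sketch (plain PyTorch, no MoE-Infinity internals) of the expert-parallel path this patch wires in. dispatch_local runs only the experts that actually received tokens, places expert e on GPU e % num_gpus (the "id hash" from the commit message), and the model layers (grok.py, mixtral.py, nllb_moe.py, switch_transformers.py) scatter each returned expert output back into the token dimension, scaled by the routing weights. DummyExpert, the tensor shapes, and the returned tuples below are illustrative stand-ins for the real C++ ExpertDispatcher bindings (set_inputs, enqueue_expert, wait_expert), which this sketch does not reproduce.

    import torch
    import torch.nn.functional as F


    class DummyExpert(torch.nn.Module):
        """Stand-in for one expert FFN; the real experts live behind the C++ dispatcher."""

        def __init__(self, hidden_dim: int):
            super().__init__()
            self.proj = torch.nn.Linear(hidden_dim, hidden_dim)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)


    def dispatch_local_sketch(hidden_states, router_mask, experts, num_gpus=1):
        # Mirrors DistributedExpertExecutor.dispatch_local from the patch: build the list
        # of experts with at least one routed token, assign each to a GPU by
        # expert_id % num_gpus, run it, and return (output, gpu_id, expert_id, hit) tuples.
        num_experts = router_mask.shape[-1]
        expert_count = router_mask.view(-1, num_experts).sum(dim=0)
        expert_list = torch.nonzero(expert_count, as_tuple=True)[0].tolist()

        results = []
        for expert_id in expert_list:
            gpu_id = expert_id % num_gpus  # "id hash" placement
            token_indices = router_mask[:, expert_id].bool()
            output = experts[expert_id](hidden_states[token_indices])
            results.append((output, gpu_id, expert_id, True))
        return results


    num_tokens, hidden_dim, num_experts, top_k = 10, 16, 8, 2
    hidden_states = torch.randn(num_tokens, hidden_dim)

    # Top-k routing; router_mask marks which experts each token uses,
    # routing_weights_mask carries the per-token, per-expert combination weight.
    router_logits = torch.randn(num_tokens, num_experts)
    routing_weights, selected = torch.topk(F.softmax(router_logits, dim=-1), top_k, dim=-1)
    router_mask = F.one_hot(selected, num_classes=num_experts).sum(dim=1)
    routing_weights_mask = torch.zeros(num_tokens, num_experts).scatter_(1, selected, routing_weights)

    experts = [DummyExpert(hidden_dim) for _ in range(num_experts)]
    final_hidden_states = torch.zeros_like(hidden_states)

    with torch.no_grad():
        for output, _, idx, _ in dispatch_local_sketch(hidden_states, router_mask, experts):
            token_indices = router_mask[:, idx].bool()
            # Same combination rule as the new mixtral.py / grok.py forward paths.
            final_hidden_states[token_indices, :] += output * routing_weights_mask[token_indices, idx][:, None]

    print(final_hidden_states.shape)  # torch.Size([10, 16])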
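
Reviewer note: the expert_dispatcher.cpp changes add a cudaStreamSynchronize(0) before launching an expert ("make sure the input is ready") and stream.synchronize() after compute and in OutputFunc. The ordering hazard they guard against can be sketched with PyTorch's stream API; this is an illustration of the idea under assumed tensor sizes, not the dispatcher's actual code.

    import torch

    if torch.cuda.is_available():
        side_stream = torch.cuda.Stream()
        x_cpu = torch.randn(1024, 1024, pin_memory=True)

        # Async H2D copy issued on the default stream (analogous to the prefetch copy).
        x_gpu = x_cpu.to("cuda", non_blocking=True)

        # "Make sure the input is ready" before the expert kernel runs on another stream.
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            y = x_gpu @ x_gpu  # stand-in for the expert forward pass

        # Hand the result back to the default stream (cf. stream.synchronize() in OutputFunc).
        torch.cuda.current_stream().wait_stream(side_stream)
        print(y.sum().item())
    else:
        print("CUDA not available; nothing to demonstrate on CPU.")

The patch opts for full stream synchronization, which is simple and safe but stalls the host; cross-stream waits or CUDA events, as above, express the same producer-consumer ordering without a host-side block.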