Commit
xly committed Apr 25, 2024
commit fb780d5 (1 parent: 8ef1206)
Showing 23 changed files with 2,397 additions and 98 deletions.
@@ -0,0 +1 @@
from .constants import *
@@ -0,0 +1,43 @@
from transformers import (
    SwitchTransformersForConditionalGeneration,
    NllbMoeForConditionalGeneration,
    MixtralForCausalLM,
    OPTForCausalLM,
    PretrainedConfig,
)

from ..models.modeling_grok.modeling_grok1 import Grok1ModelForCausalLM  # TODO: Replace this with huggingface transformers
from ..models.modeling_arctic import ArcticForCausalLM  # TODO: Replace this with huggingface transformers

# Maps a short architecture keyword to the model class used to load it.
MODEL_MAPPING_NAMES = {
    "switch": SwitchTransformersForConditionalGeneration,
    "nllb": NllbMoeForConditionalGeneration,
    "mixtral": MixtralForCausalLM,
    "opt": OPTForCausalLM,
    "grok": Grok1ModelForCausalLM,
    "arctic": ArcticForCausalLM,
}

# Integer expert-type codes consumed downstream; note that "opt" has no entry here.
MODEL_MAPPING_TYPES = {
    "switch": 0,
    "nllb": 2,
    "mixtral": 4,
    "grok": 4,
    "arctic": 4,
}

def parse_expert_type(config: PretrainedConfig) -> int:
    """Return the expert-type code for the first architecture listed in `config`."""
    architecture = config.architectures[0].lower()
    arch = None
    for supp_arch in MODEL_MAPPING_NAMES:
        if supp_arch in architecture:
            arch = supp_arch
            break
    if arch is None:
        raise RuntimeError(
            f"The `load_checkpoint_and_dispatch` function does not support the architecture {architecture}. "
            f"Please provide a model that is supported by the function. "
            f"Supported architectures are {list(MODEL_MAPPING_NAMES.keys())}."
        )

    return MODEL_MAPPING_TYPES[arch]
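A minimal usage sketch of parse_expert_type (not part of this commit), assuming any Hugging Face checkpoint whose config lists a supported architecture; the model id below is illustrative:

from transformers import AutoConfig

# Illustrative checkpoint; any config whose architectures[0] contains a
# supported keyword ("switch", "nllb", "mixtral", ...) works the same way.
config = AutoConfig.from_pretrained("google/switch-base-8")
expert_type = parse_expert_type(config)  # "switch" matches, so this returns 0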
Empty file.
@@ -0,0 +1,61 @@
from typing import Dict, Tuple
import torch
import torch.nn.functional as F
import torch.nn as nn
from .modeling_arctic import ArcticConfig
from .modeling_arctic import ArcticMLP

from moe_infinity.utils import ArcherConfig
from .model_utils import apply_rotary_pos_emb

class SyncArcticMoeBlock(nn.Module):
    archer_config: ArcherConfig = None
    layer_id: int = None

    def __init__(self, config: ArcticConfig, layer_id: int, **kwargs):
        super().__init__()

        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([ArcticMLP(config) for _ in range(self.num_experts)])

        self.archer_tracer = None
        self.archer_engine = None
        self.expert_tensor_ids: Dict[int, int] = None

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        # Cast back to the input dtype.
        routing_weights = routing_weights.to(hidden_states.dtype)

        # Predict which experts each sequence will need and prefetch them.
        # seq_id_list, expert_predictor and expert_prefetcher are attached
        # to the module externally before forward is called.
        expert_index = selected_experts.reshape(batch_size, sequence_length, self.top_k)
        for i in range(batch_size):
            seq_id = self.seq_id_list[i]
            expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id)
            self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # One-hot encode the selected experts to create an expert mask;
        # this is used to easily index which expert will be solicited.
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)
        return final_hidden_states, expert_mask
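To make the routing arithmetic in forward easy to inspect in isolation, here is a self-contained sketch of the softmax/top-k/one-hot steps (not part of the commit; all sizes are made up for the example):

import torch
import torch.nn.functional as F

# Standalone illustration of the routing math in SyncArcticMoeBlock.forward.
batch, seq, hidden, n_experts, top_k = 2, 4, 8, 4, 2
hidden_states = torch.randn(batch * seq, hidden)
gate = torch.nn.Linear(hidden, n_experts, bias=False)
router_logits = gate(hidden_states)                     # (batch*seq, n_experts)
weights = F.softmax(router_logits, dim=1, dtype=torch.float)
weights, selected = torch.topk(weights, top_k, dim=-1)  # per-token top-k experts
mask = F.one_hot(selected, num_classes=n_experts).permute(2, 1, 0)
print(mask.shape)  # torch.Size([4, 2, 8]) -> (n_experts, top_k, batch*seq)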
@@ -0,0 +1,18 @@
import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies rotary position embeddings to q and k, moving the cos/sin
    lookups onto the devices of position_ids and q as needed."""
    device = position_ids.device
    position_ids = position_ids.to(cos.device)
    cos = cos[position_ids].unsqueeze(unsqueeze_dim).to(q.device)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim).to(q.device)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    position_ids = position_ids.to(device)
    return q_embed, k_embed
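A shape-check sketch for apply_rotary_pos_emb (not part of the commit); the cos/sin tables are built the usual RoPE way, and all sizes are illustrative:

import torch

batch, heads, seq, head_dim, max_pos = 2, 4, 8, 16, 32
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
# Standard rotary frequency table: (max_pos, head_dim).
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
freqs = torch.outer(torch.arange(max_pos).float(), inv_freq)
emb = torch.cat((freqs, freqs), dim=-1)
cos, sin = emb.cos(), emb.sin()
position_ids = torch.arange(seq).expand(batch, seq)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
print(q_embed.shape, k_embed.shape)  # torch.Size([2, 4, 8, 16]) each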
@@ -0,0 +1,3 @@
from .modeling_arctic import ArcticForCausalLM, apply_rotary_pos_emb, ArcticMLP, ArcticMoE
from .configuration_arctic import ArcticConfig
from .tokenization_arctic import ArcticTokenizer
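With these re-exports in place, callers can import the Arctic classes from the subpackage root; a sketch, assuming the package path implied by this commit's relative imports (moe_infinity.models.modeling_arctic), which may differ in the released package:

from moe_infinity.models.modeling_arctic import (
    ArcticConfig,        # re-exported from configuration_arctic
    ArcticForCausalLM,   # re-exported from modeling_arctic
    ArcticTokenizer,     # re-exported from tokenization_arctic
)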