
Commit

init
xly committed Apr 25, 2024
1 parent 8ef1206 commit fb780d5
Showing 23 changed files with 2,397 additions and 98 deletions.
2 changes: 2 additions & 0 deletions examples/interface_example.py
@@ -51,6 +51,8 @@
custom_kwargs = {"forced_bos_token_id": 256057} # translate to French
elif "mixtral" in args.model_name_or_path.lower():
custom_kwargs = {"pad_token_id": tokenizer.eos_token_id}
elif "arctic" in args.model_name_or_path.lower():
custom_kwargs = {"pad_token_id": tokenizer.eos_token_id}
else:
raise ValueError(f"Model {args.model_name_or_path} not supported")

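The rest of interface_example.py is collapsed in this diff; as a rough sketch (not part of the commit), the per-model custom_kwargs dict is typically forwarded to generation together with the tokenized prompt:

# Hypothetical usage sketch; `model` and `tokenizer` are assumed to be loaded
# earlier in the example script from args.model_name_or_path.
inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=32, **custom_kwargs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))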
1 change: 1 addition & 0 deletions moe_infinity/common/__init__.py
@@ -0,0 +1 @@
from .constants import *
43 changes: 43 additions & 0 deletions moe_infinity/common/constants.py
@@ -0,0 +1,43 @@
from transformers import (
SwitchTransformersForConditionalGeneration,
NllbMoeForConditionalGeneration,
MixtralForCausalLM,
OPTForCausalLM,
PretrainedConfig,
)

from ..models.modeling_grok.modeling_grok1 import Grok1ModelForCausalLM # TODO: Replace this with huggingface transformers
from ..models.modeling_arctic import ArcticForCausalLM # TODO: Replace this with huggingface transformers

MODEL_MAPPING_NAMES = {
"switch": SwitchTransformersForConditionalGeneration,
"nllb": NllbMoeForConditionalGeneration,
"mixtral": MixtralForCausalLM,
"opt": OPTForCausalLM,
"grok": Grok1ModelForCausalLM,
"arctic": ArcticForCausalLM,
}

MODEL_MAPPING_TYPES = {
"switch": 0,
"nllb": 2,
"mixtral": 4,
"grok": 4,
"arctic": 4,
}

def parse_expert_type(config: PretrainedConfig) -> int:
architecture = config.architectures[0].lower()
arch = None
for supp_arch in MODEL_MAPPING_NAMES:
if supp_arch in architecture:
arch = supp_arch
break
if arch is None:
raise RuntimeError(
f"The `load_checkpoint_and_dispatch` function does not support the architecture {architecture}. "
f"Please provide a model that is supported by the function. "
f"Supported architectures are {list(MODEL_MAPPING_NAMES.keys())}."
)

return MODEL_MAPPING_TYPES[arch]
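A small usage sketch for parse_expert_type (the checkpoint name is an assumption, not something this commit pins down): the function lowercases config.architectures[0], substring-matches it against the keys of MODEL_MAPPING_NAMES, and returns the matching entry of MODEL_MAPPING_TYPES.

# Illustrative sketch; the checkpoint name below is an assumption.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-v0.1")
# "MixtralForCausalLM".lower() contains "mixtral", so this returns MODEL_MAPPING_TYPES["mixtral"], i.e. 4.
expert_type = parse_expert_type(config)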
15 changes: 12 additions & 3 deletions moe_infinity/entrypoints/big_modeling.py
@@ -8,11 +8,12 @@
from accelerate import init_empty_weights

from accelerate.utils.versions import is_torch_version
from moe_infinity.utils.constants import MODEL_MAPPING_NAMES
from moe_infinity.common.constants import MODEL_MAPPING_NAMES
from moe_infinity.runtime import OffloadEngine
from moe_infinity.utils import get_checkpoint_paths, ArcherConfig
from moe_infinity.models import apply_rotary_pos_emb
import moe_infinity
from moe_infinity.models.modeling_arctic import ArcticConfig


class MoE:
@@ -64,7 +65,10 @@ def __init__(
f"Please provide a configuration file or create a default one at {default_config_path}."
)
config = default_config_path
model_config = AutoConfig.from_pretrained(model_name_or_path)
if "arctic" in model_name_or_path:
model_config = ArcticConfig.from_pretrained(model_name_or_path)
else:
model_config = AutoConfig.from_pretrained(model_name_or_path)
architecture = model_config.architectures[0].lower()

arch = None
@@ -136,7 +140,12 @@ def _configure_hook(self, input_ids: torch.LongTensor):
)

if self.arch == "grok":
moe_infinity.modeling_grok.modeling_grok1.apply_rotary_pos_emb = (
moe_infinity.models.modeling_grok.modeling_grok1.apply_rotary_pos_emb = (
apply_rotary_pos_emb
)

if self.arch == "arctic":
moe_infinity.models.modeling_arctic.modeling_arctic.apply_rotary_pos_emb = (
apply_rotary_pos_emb
)

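With the Arctic branch above in place, a minimal end-to-end call might look like the following sketch; the checkpoint name and the offload config values are assumptions layered on the existing MoE entrypoint, not part of this diff.

# Minimal sketch, assuming the public MoE entrypoint and an Arctic checkpoint on the Hub.
from moe_infinity import MoE

config = {"offload_path": "/tmp/moe-infinity", "device_memory_ratio": 0.75}  # assumed values
model = MoE("Snowflake/snowflake-arctic-instruct", config)  # takes the ArcticConfig branch above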
4 changes: 3 additions & 1 deletion moe_infinity/models/__init__.py
@@ -5,5 +5,7 @@

from .switch_transformers import SyncSwitchTransformersSparseMLP
from .nllb_moe import SyncNllbMoeSparseMLP
from .mixtral import SyncMixtralSparseMoeBlock, apply_rotary_pos_emb
from .mixtral import SyncMixtralSparseMoeBlock
from .grok import SyncGrokMoeBlock
from .arctic import SyncArcticMoeBlock, ArcticConfig
from .model_utils import apply_rotary_pos_emb
61 changes: 61 additions & 0 deletions moe_infinity/models/arctic.py
@@ -0,0 +1,61 @@
from typing import Dict, Optional, Tuple
import torch
import torch.nn.functional as F
import torch.nn as nn
from .modeling_arctic import ArcticConfig
from .modeling_arctic import ArcticMLP

from moe_infinity.utils import ArcherConfig
from .model_utils import apply_rotary_pos_emb

class SyncArcticMoeBlock(nn.Module):
    archer_config: ArcherConfig = None
    layer_id: int = None

    def __init__(self, config: ArcticConfig, layer_id: int, **kwargs):
        super().__init__()

        self.hidden_dim = config.hidden_size
        self.num_experts = config.num_local_experts
        self.layer_id = layer_id
        self.top_k = config.num_experts_per_tok
        self.is_moe_layer = (layer_id + 1) % config.moe_layer_frequency == 0

        self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
        self.experts = nn.ModuleList([ArcticMLP(config) for i in range(self.num_experts)])

        self.archer_tracer = None
        self.archer_engine = None
        self.expert_tensor_ids: Dict[int, int] = None

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor]:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(
            routing_weights, self.top_k, dim=-1
        )
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        expert_index = selected_experts.reshape(batch_size, sequence_length, self.top_k)
        for i in range(batch_size):
            seq_id = self.seq_id_list[i]
            expert_matrix = self.expert_predictor.predict(seq_id, expert_index[i], self.layer_id)
            self.expert_prefetcher.prefetch_experts(self.layer_id, expert_matrix)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        # One hot encode the selected experts to create an expert mask
        # this will be used to easily index which expert is going to be solicited
        expert_mask = torch.nn.functional.one_hot(
            selected_experts, num_classes=self.num_experts
        ).permute(2, 1, 0)
        return final_hidden_states, expert_mask
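The routing math in forward can be exercised on its own with dummy tensors; the sketch below (arbitrary shapes, no Arctic weights or prefetching machinery) shows the top-k selection and the layout of the one-hot expert mask.

# Standalone sketch of the routing above with made-up shapes.
import torch
import torch.nn.functional as F

n_experts, top_k = 128, 1
router_logits = torch.randn(6, n_experts)  # (batch * sequence_length, n_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
expert_mask = F.one_hot(selected_experts, num_classes=n_experts).permute(2, 1, 0)
print(selected_experts.shape, expert_mask.shape)  # torch.Size([6, 1]) torch.Size([128, 1, 6])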
6 changes: 4 additions & 2 deletions moe_infinity/models/grok.py
@@ -2,11 +2,13 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from ..modeling_grok.configuration_grok1 import Grok1Config
from ..modeling_grok.modeling_grok1 import MoeBlock, MoeMLP, rotate_half
from .modeling_grok import Grok1Config
from .modeling_grok import MoeBlock, MoeMLP


from moe_infinity.utils import ArcherConfig
from .model_utils import apply_rotary_pos_emb



class SyncGrokMoeBlock(nn.Module):
13 changes: 1 addition & 12 deletions moe_infinity/models/mixtral.py
Expand Up @@ -8,25 +8,14 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
import transformers
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import (
    MixtralBLockSparseTop2MLP,
    rotate_half,
)

from moe_infinity.utils import ArcherConfig

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    device = position_ids.device
    position_ids = position_ids.to(cos.device)
    cos = cos[position_ids].unsqueeze(unsqueeze_dim).to(q.device)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim).to(q.device)
    # print("cos.shape", cos.device, "sin.shape", sin.device, "q.shape", q.device, "k.shape", k.device)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    position_ids = position_ids.to(device)
    return q_embed, k_embed

class SyncMixtralSparseMoeBlock(nn.Module):
    archer_config: ArcherConfig = None
    layer_id: int = None
18 changes: 18 additions & 0 deletions moe_infinity/models/model_utils.py
@@ -0,0 +1,18 @@
import torch

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    device = position_ids.device
    position_ids = position_ids.to(cos.device)
    cos = cos[position_ids].unsqueeze(unsqueeze_dim).to(q.device)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim).to(q.device)
    # print("cos.shape", cos.device, "sin.shape", sin.device, "q.shape", q.device, "k.shape", k.device)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    position_ids = position_ids.to(device)
    return q_embed, k_embed
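A quick shape check for the two helpers above (dummy tensors only; the cos/sin layout assumed here is the older (max_seq_len, head_dim) rotary cache that this indexing expects):

# Smoke test with dummy tensors; cos/sin shaped (max_seq_len, head_dim) is an assumption.
import torch

batch, heads, seq_len, head_dim = 2, 4, 8, 16
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
cos = torch.randn(32, head_dim)
sin = torch.randn(32, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1)
q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
assert q_embed.shape == q.shape and k_embed.shape == k.shape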
3 changes: 3 additions & 0 deletions moe_infinity/models/modeling_arctic/__init__.py
@@ -0,0 +1,3 @@
from .modeling_arctic import ArcticForCausalLM, apply_rotary_pos_emb, ArcticMLP, ArcticMoE
from .configuration_arctic import ArcticConfig
from .tokenization_arctic import ArcticTokenizer