[FEAT][Customization]
Kye committed Jan 13, 2024
1 parent ea8af7a commit 46eddb7
Showing 3 changed files with 26 additions and 6 deletions.
README.md: 3 changes (3 additions & 0 deletions)
@@ -35,6 +35,9 @@ model = MambaTransformer(
d_state=512, # Dimension of the state
dropout=0.1, # Dropout rate
ff_mult=4, # Multiplier for the feed-forward layer dimension
return_embeddings=False, # Whether to return the embeddings
transformer_depth=2, # Number of transformer blocks
mamba_depth=10, # Number of Mamba blocks
)

# Pass the input tensor through the model and print the output shape
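For reference, a fuller sketch of the README example with the new customization options in place. The leading constructor arguments (num_tokens, dim, heads, depth, dim_head) and the import path are not visible in this hunk and are assumed from the repository's existing example; the expected output shape comes from the model docstring.

import torch
from mamba_transformer import MambaTransformer  # import path assumed

model = MambaTransformer(
    num_tokens=100,           # Vocabulary size (assumed)
    dim=512,                  # Model dimension (assumed)
    heads=8,                  # Number of attention heads (assumed)
    depth=4,                  # Depth passed through to the block (assumed)
    dim_head=64,              # Dimension per attention head (assumed)
    d_state=512,              # Dimension of the state
    dropout=0.1,              # Dropout rate
    ff_mult=4,                # Multiplier for the feed-forward layer dimension
    return_embeddings=False,  # Whether to return the embeddings
    transformer_depth=2,      # Number of transformer blocks (new)
    mamba_depth=10,           # Number of Mamba blocks (new)
)

# Pass the input tensor through the model and print the output shape
x = torch.randint(0, 100, (1, 10))
print(model(x).shape)  # Expected: torch.Size([1, 10, 100])
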
example.py: 3 changes (3 additions & 0 deletions)
@@ -14,6 +14,9 @@
d_state=512, # Dimension of the state
dropout=0.1, # Dropout rate
ff_mult=4, # Multiplier for the feed-forward layer dimension
return_embeddings=False, # Whether to return the embeddings
transformer_depth=2, # Number of transformer blocks
mamba_depth=10, # Number of Mamba blocks
)

# Pass the input tensor through the model and print the output shape
mamba_transformer/model.py: 26 changes (20 additions & 6 deletions)
@@ -138,6 +138,8 @@ def __init__(
dropout: float = 0.1,
ff_mult: int = 4,
d_state: int = None,
transformer_depth: int = 1,
mamba_depth: int = 1,
*args,
**kwargs,
):
@@ -149,17 +151,24 @@ def __init__(
self.dropout = dropout
self.ff_mult = ff_mult
self.d_state = d_state
self.transformer_depth = transformer_depth
self.mamba_depth = mamba_depth

self.mamba_blocks = nn.ModuleList([])
self.transformer_blocks = nn.ModuleList([])
self.ffn_blocks = nn.ModuleList([])

self.mamba_blocks.append(
MambaBlock(dim, depth, d_state, *args, **kwargs)
MambaBlock(dim, mamba_depth, d_state, *args, **kwargs)
)

# Transformer and ffn blocks
for _ in range(depth):
self.ffn_blocks.append(
FeedForward(dim, dim, ff_mult, *args, **kwargs)
)

for _ in range(transformer_depth):
self.transformer_blocks.append(
MultiQueryTransformerBlock(
dim,
@@ -172,10 +181,6 @@ def __init__(
)
)

self.ffn_blocks.append(
FeedForward(dim, dim, ff_mult, *args, **kwargs)
)

# Layernorm
self.norm = nn.LayerNorm(dim)
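Taken together, the reorganized constructor now builds three independent stacks: a single MambaBlock whose internal depth is mamba_depth, depth feed-forward layers, and transformer_depth attention blocks. A standalone sketch of that counting logic is below; _Stub is a placeholder layer, not one of the repository's real modules, whose signatures are only partially visible in this diff.

import torch.nn as nn

class _Stub(nn.Module):
    # Placeholder standing in for MambaBlock / MultiQueryTransformerBlock / FeedForward.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)

dim, depth, transformer_depth, mamba_depth = 512, 4, 2, 10  # illustrative values

mamba_blocks = nn.ModuleList([_Stub(dim)])  # one Mamba stack; mamba_depth is its internal depth
ffn_blocks = nn.ModuleList([_Stub(dim) for _ in range(depth)])  # `depth` feed-forward layers
transformer_blocks = nn.ModuleList([_Stub(dim) for _ in range(transformer_depth)])  # attention blocks

print(len(mamba_blocks), len(ffn_blocks), len(transformer_blocks))  # 1 4 2
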

@@ -210,7 +215,7 @@ class MambaTransformer(nn.Module):
d_state (int, optional): The dimensionality of the state embeddings. Defaults to None.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Examples:
>>> import torch
>>> from mt import MambaTransformer
@@ -228,6 +233,7 @@ class MambaTransformer(nn.Module):
>>> print(model(x).shape)
torch.Size([1, 10, 100])
"""

def __init__(
self,
num_tokens: int,
@@ -239,6 +245,8 @@ def __init__(
ff_mult: int = 4,
d_state: int = None,
return_embeddings: bool = False,
transformer_depth: int = 1,
mamba_depth: int = 1,
*args,
**kwargs,
):
@@ -251,6 +259,8 @@
self.ff_mult = ff_mult
self.d_state = d_state
self.return_embeddings = return_embeddings
self.transformer_depth = transformer_depth
self.mamba_depth = mamba_depth

self.emb = nn.Embedding(num_tokens, dim)
self.mt_block = MambaTransformerblock(
@@ -261,6 +271,9 @@ def __init__(
dropout,
ff_mult,
d_state,
transformer_depth,
mamba_depth,
*args,
**kwargs,
)
@@ -283,5 +296,6 @@ def forward(self, x: Tensor) -> Tensor:

if self.return_embeddings:
return x

else:
return self.to_logits(x)
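
In forward(), the new return_embeddings flag simply decides whether the hidden states are returned as-is or projected to vocabulary logits via to_logits. A brief usage sketch, reusing the assumed constructor arguments from the README example above; the embedding output shape presumes the block preserves the model dimension.

import torch
from mamba_transformer import MambaTransformer  # import path assumed

common = dict(
    num_tokens=100, dim=512, heads=8, depth=4, dim_head=64,
    d_state=512, transformer_depth=2, mamba_depth=10,
)

logits_model = MambaTransformer(return_embeddings=False, **common)
embed_model = MambaTransformer(return_embeddings=True, **common)

x = torch.randint(0, 100, (1, 10))
print(logits_model(x).shape)  # torch.Size([1, 10, 100]): logits over the vocabulary
print(embed_model(x).shape)   # expected torch.Size([1, 10, 512]): hidden states, no logits projection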
