Merge pull request #7 from Liberatedwinner/patch-1
kyegomez authored Mar 15, 2024
2 parents edf2300 + ea0f66d commit 6c3739a
Showing 1 changed file with 24 additions and 28 deletions.
mamba_transformer/model.py: 52 changes (24 additions & 28 deletions)
@@ -12,10 +12,10 @@
 class RMSNorm(nn.Module):
     def __init__(self, dim: int):
         super().__init__()
-        self.scale = dim**-0.5
+        self.scale = dim ** (-0.5)
         self.g = nn.Parameter(torch.ones(dim))

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         return F.normalize(x, dim=-1) * self.scale * self.g
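
A side note on this hunk (not part of the commit): the two spellings of the scale are equivalent in Python, because the power operator binds more tightly than the unary minus on its right-hand operand, so the added parentheses only make the intent explicit; likewise, the new -> Tensor annotation is a type hint and does not change behavior. A quick standalone check:

    dim = 512
    assert dim**-0.5 == dim ** (-0.5) == 1 / (dim ** 0.5)
    print(dim ** (-0.5))  # ~0.0442, i.e. the 1/sqrt(dim) scale RMSNorm multiplies by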


@@ -97,6 +97,7 @@ def forward(self, x: Tensor) -> Tensor:
         x, _, _ = self.attn(x)
         x = self.norm(x)
         x = self.ffn(x)
+
         return x


@@ -172,33 +173,28 @@ def __init__(
         self.transformer_depth = transformer_depth
         self.mamba_depth = mamba_depth

-        self.mamba_blocks = nn.ModuleList([])
-        self.transformer_blocks = nn.ModuleList([])
-        self.ffn_blocks = nn.ModuleList([])
-
-        self.mamba_blocks.append(
+        # Mamba, Transformer, and ffn blocks
+        self.mamba_blocks = nn.ModuleList([
             MambaBlock(dim, mamba_depth, d_state, *args, **kwargs)
-        )
-
-        # Transformer and ffn blocks
-        for _ in range(depth):
-            self.ffn_blocks.append(
-                FeedForward(dim, dim, ff_mult, *args, **kwargs)
-            )
-
-        for _ in range(transformer_depth):
-            self.transformer_blocks.append(
-                TransformerBlock(
-                    dim,
-                    heads,
-                    dim_head,
-                    dropout,
-                    ff_mult,
-                    use_linear_attn,
-                    *args,
-                    **kwargs,
-                )
-            )
+            for _ in range(mamba_depth)
+        ])
+        self.transformer_blocks = nn.ModuleList([
+            TransformerBlock(
+                dim,
+                heads,
+                dim_head,
+                dropout,
+                ff_mult,
+                use_linear_attn,
+                *args,
+                **kwargs,
+            ) for _ in range(transformer_depth)
+        ])
+
+        self.ffn_blocks = nn.ModuleList([
+            FeedForward(dim, dim, ff_mult, *args, **kwargs)
+            for _ in range(depth)
+        ])

         # Layernorm
         self.norm = nn.LayerNorm(dim)
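
For context (an observation on the hunk above, not part of the commit): the refactor swaps the append-in-a-loop construction for list comprehensions passed straight to nn.ModuleList, and in the process the Mamba branch goes from a single appended MambaBlock to mamba_depth of them. A minimal sketch of the same pattern, using nn.Linear as a stand-in since MambaBlock itself is not shown in this diff:

    import torch.nn as nn

    dim, mamba_depth = 512, 3

    # Old style: start empty and append (the old code appended exactly one block)
    blocks_old = nn.ModuleList([])
    blocks_old.append(nn.Linear(dim, dim))

    # New style: build the whole list in one comprehension, one block per depth step
    blocks_new = nn.ModuleList([nn.Linear(dim, dim) for _ in range(mamba_depth)])

    print(len(blocks_old), len(blocks_new))  # 1 vs 3

Both forms register their children as submodules, so parameters show up in .parameters() either way; the comprehension form just removes the mutable-container boilerplate.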
