
Multi-Head Latent Attention (MLA)


Quick Start

Install ohara:

pip install ohara

To train MLA:

python train_mla.py --attn_type=mla

For baseline, use MHA:

python train_mla.py --attn_type=mha 
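
For context, here is a minimal sketch of the MLA idea: instead of caching full per-head keys and values, the hidden state is down-projected into a small shared latent, and keys/values are up-projected from that latent. The class name, dimensions, and the omission of RoPE and KV caching are illustrative assumptions, not the ohara implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLASketch(nn.Module):
    """Minimal sketch of Multi-Head Latent Attention (no RoPE, no KV cache)."""

    def __init__(self, d_model=512, n_heads=8, kv_latent_dim=64):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        # Down-project hidden states into a shared low-rank KV latent.
        self.kv_down = nn.Linear(d_model, kv_latent_dim, bias=False)
        # Up-project the latent back to full per-head keys and values.
        self.k_up = nn.Linear(kv_latent_dim, d_model, bias=False)
        self.v_up = nn.Linear(kv_latent_dim, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, D = x.shape
        q = self.q_proj(x)
        c_kv = self.kv_down(x)   # (B, T, kv_latent_dim) -- this latent is what gets cached
        k = self.k_up(c_kv)
        v = self.v_up(c_kv)
        # Split into heads: (B, n_heads, T, head_dim)
        q, k, v = (t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
                   for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        out = out.transpose(1, 2).reshape(B, T, D)
        return self.out_proj(out)
```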

If you want to calculate the number of parameters and check what % of KV cache you'll save, visit this link: https://joey00072.github.io/Multi-Head-Latent-Attention-MLA-/
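
As a rough back-of-the-envelope version of that calculation (illustrative numbers and a hypothetical latent size, not output from the linked calculator):

```python
# MHA caches full K and V per layer; MLA caches only the compressed KV latent.
d_model, n_layers, seq_len = 4096, 32, 8192
kv_latent_dim = 512  # hypothetical MLA latent size

mha_cache = 2 * d_model * n_layers * seq_len       # floats cached by MHA
mla_cache = kv_latent_dim * n_layers * seq_len     # floats cached by MLA

print(f"KV cache saved: {100 * (1 - mla_cache / mha_cache):.1f}%")  # -> 93.8%
```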

TODO

  • Write blog post
  • Add JAX version
  • Add GQA and MQA to the calculation (index.html)
  • Maybe distill Llama into an MLA version