diff --git a/.gitignore b/.gitignore
index 5e057e7c..d4189547 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,5 +9,4 @@ outputs
 data
 out
 wandb
-*.model
 *.json
diff --git a/README.md b/README.md
index 6d27fca2..9c9dbb8e 100644
--- a/README.md
+++ b/README.md
@@ -6,17 +6,15 @@ torchtrain contains PyTorch native parallelisms, tools and utilities to train la
 
 # Installation
 
-install PyTorch from source or install the latest pytorch nightly, then install requirements by
+Install PyTorch from source or install the latest pytorch nightly, then install requirements by
 ```python
 pip install -r requirements.txt
 ```
 
-download tokenizer from HF
-This part is needed first time if there's no tokenizer locally by run:
-
+Install additional dev requirements if you want to contribute to the repo:
 ```
-python torchtrain/datasets/download_tokenizer.py --hf_token your_token
+pip install -r dev-requirements.txt
 ```
 
 run the llama debug model locally to verify the setup is correct:
 
diff --git a/torchtrain/datasets/tokenizer/tokenizer.model b/torchtrain/datasets/tokenizer/tokenizer.model
new file mode 100644
index 00000000..22bccbcb
Binary files /dev/null and b/torchtrain/datasets/tokenizer/tokenizer.model differ
diff --git a/torchtrain/models/llama/__init__.py b/torchtrain/models/llama/__init__.py
index 8b70ce3d..c1f87f89 100644
--- a/torchtrain/models/llama/__init__.py
+++ b/torchtrain/models/llama/__init__.py
@@ -6,9 +6,11 @@
 __all__ = ["Transformer"]
 
 llama_configs = {
-    "debugmodel": ModelArgs(dim=256, n_layers=1, n_heads=16),
+    "debugmodel": ModelArgs(dim=256, n_layers=2, n_heads=16),
+    "1B": ModelArgs(dim=1024, n_layers=16, n_heads=8),
     "7B": ModelArgs(dim=4096, n_layers=32, n_heads=32),
     "13B": ModelArgs(dim=5120, n_layers=40, n_heads=40),
+    "40B": ModelArgs(dim=5120, n_layers=80, n_heads=40),
     "70B": ModelArgs(
         dim=8192,
         n_layers=80,
diff --git a/torchtrain/models/llama/model.py b/torchtrain/models/llama/model.py
index 7485d32b..aade2940 100644
--- a/torchtrain/models/llama/model.py
+++ b/torchtrain/models/llama/model.py
@@ -8,6 +8,8 @@
 import torch.nn.functional as F
 from torch import nn
 
+from torchtrain.logging_utils import rank0_log
+
 
 @dataclass
 class ModelArgs:
@@ -165,7 +167,7 @@ class Attention(nn.Module):
     Multi-head attention module.
 
     Args:
-        args (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
 
     Attributes:
         n_kv_heads (int): Number of key and value heads.
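Note: the new `rank0_log` import above comes from `torchtrain.logging_utils`, which is not touched by this patch. For readers following along, a minimal sketch of such a rank-0-only logger might look like the following (an assumption for illustration, not the repo's actual implementation):

```python
# Illustrative sketch only -- the real torchtrain.logging_utils.rank0_log may differ.
import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)


def rank0_log(msg: str) -> None:
    # Emit the message only on global rank 0 so multi-GPU runs
    # do not print the same line once per rank.
    if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
        logger.info(msg)
```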
@@ -182,18 +184,39 @@ class Attention(nn.Module):
 
     """
 
-    def __init__(self, args: ModelArgs):
-
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        self.n_heads = model_args.n_heads
+        self.n_kv_heads = (
+            model_args.n_heads
+            if model_args.n_kv_heads is None
+            else model_args.n_kv_heads
+        )
         self.n_rep = self.n_heads // self.n_kv_heads
-        self.head_dim = args.dim // args.n_heads
+        self.head_dim = model_args.dim // model_args.n_heads
+
+        self.wq = nn.Linear(
+            model_args.dim, model_args.n_heads * self.head_dim, bias=False
+        )
+        self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False)
+        self.wo = nn.Linear(
+            model_args.n_heads * self.head_dim, model_args.dim, bias=False
+        )
 
-        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
-        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
-        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
+    def reset_parameters(self, init_std):
+        for item in (self.wq, self.wk, self.wv):
+            nn.init.trunc_normal_(
+                item.weight,
+                mean=0.0,
+                std=0.02,
+            )
+
+        nn.init.trunc_normal_(
+            self.wo.weight,
+            mean=0.0,
+            std=init_std,
+        )
 
     def forward(
         self,
@@ -269,7 +292,6 @@ def __init__(
         multiple_of: int,
         ffn_dim_multiplier: Optional[float],
     ):
-
         super().__init__()
         hidden_dim = int(2 * hidden_dim / 3)
         # custom dim factor multiplier
@@ -284,24 +306,38 @@ def __init__(
     def forward(self, x):
         return self.w2(F.silu(self.w1(x)) * self.w3(x))
 
+    def reset_parameters(self, init_std):
+        nn.init.trunc_normal_(
+            self.w1.weight,
+            mean=0.0,
+            std=0.02,
+        )
+
+        for item in (self.w2, self.w3):
+            nn.init.trunc_normal_(
+                item.weight,
+                mean=0.0,
+                std=init_std,
+            )
+
 
 class RotaryEmbedding(nn.Module):
     """
     RotaryEmbedding Module
     """
 
-    def __init__(self, params: ModelArgs):
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.params = params
-        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
+        self.model_args = model_args
+        self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim)
 
         self.freqs_cis = precompute_freqs_cis(
-            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
+            # Note that self.model_args.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation
             # of models is 4096.
             # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training
             # or fine-tuning.
-            self.params.dim // self.params.n_heads,
-            self.params.max_seq_len * 2,
+            self.model_args.dim // self.model_args.n_heads,
+            self.model_args.max_seq_len * 2,
         )
 
     def forward(self, tokens: torch.Tensor):
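Note: the `reset_parameters` methods added above give `wq`/`wk`/`wv` and `w1` a fixed truncated-normal std of 0.02, while `wo` and `w2`/`w3` use the `init_std` passed in by the caller (the depth-scaled value set on each `TransformerBlock` later in this patch). A standalone sketch of the same split on plain `nn.Linear` layers, with made-up sizes:

```python
# Standalone illustration of the init split used in this patch; sizes are made up.
from torch import nn

dim, hidden_dim = 256, 688
w_in = nn.Linear(dim, hidden_dim, bias=False)   # input-side projection -> fixed std
w_out = nn.Linear(hidden_dim, dim, bias=False)  # output-side projection -> scaled std

init_std = 0.02 / (2 * 2) ** 0.5  # what a 2-layer model would pass in (0.01)
nn.init.trunc_normal_(w_in.weight, mean=0.0, std=0.02)
nn.init.trunc_normal_(w_out.weight, mean=0.0, std=init_std)

print(w_in.weight.std().item(), w_out.weight.std().item())  # roughly 0.02 and 0.01
```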
@@ -341,21 +377,22 @@ class TransformerBlock(nn.Module):
 
     """
 
-    def __init__(self, layer_id: int, args: ModelArgs):
-
+    def __init__(self, layer_id: int, model_args: ModelArgs):
         super().__init__()
-        self.n_heads = args.n_heads
-        self.dim = args.dim
-        self.attention = Attention(args)
+        self.n_heads = model_args.n_heads
+        self.dim = model_args.dim
+        self.attention = Attention(model_args)
         self.feed_forward = FeedForward(
-            dim=args.dim,
-            hidden_dim=4 * args.dim,
-            multiple_of=args.multiple_of,
-            ffn_dim_multiplier=args.ffn_dim_multiplier,
+            dim=model_args.dim,
+            hidden_dim=4 * model_args.dim,
+            multiple_of=model_args.multiple_of,
+            ffn_dim_multiplier=model_args.ffn_dim_multiplier,
         )
         self.layer_id = layer_id
-        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
-        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.num_layers = model_args.n_layers
+        self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
 
     def forward(
         self,
@@ -377,16 +414,24 @@ def forward(
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
+    def reset_parameters(self):
+        """reset params and norms for entire block"""
+        self.attention_norm.reset_parameters()
+        self.ffn_norm.reset_parameters()
+
+        self.attention.reset_parameters(self.weight_init_std)
+        self.feed_forward.reset_parameters(self.weight_init_std)
+
 
 class Transformer(nn.Module):
     """
     Transformer Module
 
     Args:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
 
     Attributes:
-        params (ModelArgs): Model configuration parameters.
+        model_args (ModelArgs): Model configuration arguments.
         vocab_size (int): Vocabulary size.
         n_layers (int): Number of layers in the model.
         tok_embeddings (ParallelEmbedding): Token embeddings.
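Note: the `weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5` added to `TransformerBlock` above shrinks the output-projection init with depth, similar in spirit to GPT-2's 1/sqrt(N) residual scaling over the 2 * n_layers residual branches. Evaluated for the configs touched in this patch:

```python
# Depth-scaled init std from this patch, evaluated for a few configs.
for name, n_layers in [("debugmodel", 2), ("1B", 16), ("7B", 32), ("40B", 80), ("70B", 80)]:
    std = 0.02 / (2 * n_layers) ** 0.5
    print(f"{name:10s} n_layers={n_layers:3d} -> weight_init_std={std:.5f}")
# debugmodel 0.01000, 1B 0.00354, 7B 0.00250, 40B/70B 0.00158
```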
@@ -397,21 +442,42 @@ class Transformer(nn.Module):
 
     """
 
-    def __init__(self, params: ModelArgs):
-
+    def __init__(self, model_args: ModelArgs):
         super().__init__()
-        self.params = params
-        self.vocab_size = params.vocab_size
-        self.n_layers = params.n_layers
+        self.model_args = model_args
+        self.vocab_size = model_args.vocab_size
+        self.n_layers = model_args.n_layers
+        self.model_dim = model_args.dim
 
-        self.embeddings = RotaryEmbedding(params)
+        self.embeddings = RotaryEmbedding(model_args)
 
         self.layers = torch.nn.ModuleList()
-        for layer_id in range(params.n_layers):
-            self.layers.append(TransformerBlock(layer_id, params))
+        for layer_id in range(model_args.n_layers):
+            self.layers.append(TransformerBlock(layer_id, model_args))
+
+        self.norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
+        self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
 
-        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
-        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
+        # init model weights
+        self.reset_parameters()
+        rank0_log(f"Model built with: {self.model_args}")
+
+    def reset_parameters(
+        self,
+    ):
+        for layer in self.layers:
+            layer.reset_parameters()
+        self.norm.reset_parameters()
+        final_out_std = self.model_dim**-0.5
+        cutoff_factor = 3
+        nn.init.trunc_normal_(
+            self.output.weight,
+            mean=0.0,
+            std=final_out_std,
+            a=-cutoff_factor * final_out_std,
+            b=cutoff_factor * final_out_std,
+        )
+        rank0_log("Model fully initialized via reset_params")
 
     def forward(self, tokens: torch.Tensor):
         """
@@ -426,7 +492,7 @@ def forward(self, tokens: torch.Tensor):
         """
         h, freqs_cis = self.embeddings(tokens)
         # fold batch and sequence dimension for more efficient allgather/reduce_scatter
-        h = h.view(-1, self.params.dim)
+        h = h.view(-1, self.model_args.dim)
 
         for layer in self.layers:
             h = layer(h, freqs_cis)
@@ -435,17 +501,17 @@ def forward(self, tokens: torch.Tensor):
         # unfold batch and sequence dimension
         bsz = tokens.shape[0]
         bs_seqlen = h.shape[0]
-        h = h.view(bsz, bs_seqlen // bsz, self.params.dim)
+        h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
         output = self.output(h).float()
         return output
 
     @classmethod
-    def from_model_args(cls, model_args: ModelArgs):
+    def from_model_args(cls, model_args: ModelArgs) -> "Transformer":
         """
         Initialize a Transformer model from a ModelArgs object.
 
         Args:
-            model_args (ModelArgs): Model configuration parameters.
+            model_args (ModelArgs): Model configuration arguments.
 
         Returns:
             Transformer: Transformer model.
diff --git a/train.py b/train.py
index 1f0e9f75..a0a0f891 100644
--- a/train.py
+++ b/train.py
@@ -79,6 +79,7 @@ def main(args):
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
 
     model_name = args.model
+    rank0_log(f"Building {model_name}")
     # build tokenizer
     tokenizer_type = model_name_to_tokenizer[model_name]
     tokenizer = create_tokenizer(tokenizer_type, args.tokenizer_path)
@@ -233,7 +234,7 @@
     parser.add_argument(
         "--warmup_pct",
         type=float,
-        default=0.10,
+        default=0.20,
         help="percentage of total training steps to use for warmup",
     )
     parser.add_argument(
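Note: the final hunk doubles the default `--warmup_pct` from 0.10 to 0.20; per its help text it is a fraction of total training steps. How that typically translates into a warmup step count (an illustration under assumed values; the scheduler wiring itself is not part of this patch):

```python
# Illustration only; the repo's actual LR-scheduler code is not shown in this diff.
import math

steps = 1000                                       # hypothetical total training steps
warmup_pct = 0.20                                  # new default (previously 0.10)
warmup_steps = int(math.ceil(steps * warmup_pct))
print(warmup_steps)                                # 200 warmup steps, up from 100
```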