add configurable unique layer init, clean up lr and loss display (#64)
Small PR:
1 - Add a configurable init style in model_args: the new 'depth_init' flag uses each layer's own layer_id in the init stddev denominator; when disabled, the original init style based on the total layer count is used. (Both verified to work on 7B Llama; it is not yet clear whether one is better than the other.) A small sketch of the two formulas follows the screenshots below.

2 - Clean up the lr display formatting: the lr was printed out to 12+ digits, which isn't very informative, and was wrapped in list brackets. This PR rounds the displayed value to at most 8 digits of precision and removes the []'s around it.
(Note this is purely a UI change; the full float precision is still used in the actual lr calculations.)

3 - Clean up the loss display: round the displayed loss to 4 digits of precision to make it more readable and informative. (An illustrative before/after of the rounding follows the train.py diff below.)
Previously:
<img width="1198" alt="Screenshot 2024-02-16 at 2 33 34 PM"
src="https://github.com/pytorch-labs/torchtrain/assets/46302957/77733af0-42db-4fab-a047-fccc7d404278">

Now:
<img width="1063" alt="Screenshot 2024-02-16 at 2 51 53 PM"
src="https://github.com/pytorch-labs/torchtrain/assets/46302957/4eb75b98-67f4-41ec-83d8-dd84a0e8b29e">
lessw2020 authored Feb 22, 2024
1 parent 70be86e commit 78878f5
Showing 2 changed files with 17 additions and 7 deletions.
9 changes: 8 additions & 1 deletion torchtrain/models/llama/model.py
@@ -24,6 +24,9 @@ class ModelArgs:

    max_batch_size: int = 32
    max_seq_len: int = 32768
+   depth_init: bool = (
+       True  # initialization uses each unique layer_id or total model layer count
+   )


class RMSNorm(torch.nn.Module):
@@ -392,7 +395,11 @@ def __init__(self, layer_id: int, model_args: ModelArgs):
        self.num_layers = model_args.n_layers
        self.attention_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
        self.ffn_norm = RMSNorm(model_args.dim, eps=model_args.norm_eps)
-        self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5
+
+        if model_args.depth_init:
+            self.weight_init_std = 0.02 / (2 * (self.layer_id + 1)) ** 0.5
+        else:
+            self.weight_init_std = 0.02 / (2 * self.num_layers) ** 0.5

    def forward(
        self,
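For completeness, a hypothetical snippet of how the new flag would be toggled when building the model args (the field names match the diff above; the construction itself is illustrative and assumes ModelArgs is instantiated directly):

```python
from torchtrain.models.llama.model import ModelArgs

# per-layer (layer_id-based) init stddev -- the new default
args_depth = ModelArgs(n_layers=32, depth_init=True)

# original behavior: every layer scaled by the total layer count
args_total = ModelArgs(n_layers=32, depth_init=False)
```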
15 changes: 9 additions & 6 deletions train.py
@@ -207,12 +207,14 @@ def main(args):

        # log metrics
        if (train_state.step - 1) % args.log_freq == 0:
-            avg_loss, max_loss = np.mean(losses_since_last_log), np.max(
-                losses_since_last_log
+            avg_loss, max_loss = (
+                np.mean(losses_since_last_log),
+                np.max(losses_since_last_log),
            )
+            global_avg_loss, global_max_loss = (
+                dist_mean(avg_loss, world_mesh),
+                dist_max(max_loss, world_mesh),
+            )
-            global_avg_loss, global_max_loss = dist_mean(
-                avg_loss, world_mesh
-            ), dist_max(max_loss, world_mesh)

            time_delta = timer() - time_last_log
            wps = nwords_since_last_log / (
@@ -239,7 +241,8 @@ def main(args):
            time_last_log = timer()

        rank0_log(
-            f"step: {train_state.step}, current loss: {train_state.current_loss}, lr: {scheduler.get_last_lr()}"
+            f"step: {train_state.step}, current loss: {round(train_state.current_loss,4)},"
+            f" lr: {round(float(scheduler.get_last_lr()[0]), 8)}"
        )
        scheduler.step()
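As a quick before/after of the display-only rounding from items 2 and 3 (the numeric values here are made up for illustration):

```python
current_loss = 10.523481845855713
last_lr = [8.000000000000001e-05]  # LR schedulers return a list of learning rates

old_msg = f"current loss: {current_loss}, lr: {last_lr}"
new_msg = f"current loss: {round(current_loss, 4)}, lr: {round(float(last_lr[0]), 8)}"

print(old_msg)  # current loss: 10.523481845855713, lr: [8.000000000000001e-05]
print(new_msg)  # current loss: 10.5235, lr: 8e-05
```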
