Commit

align (#30)
oahzxl authored Feb 24, 2024
1 parent 90e8d6c commit 9db1f1e
Showing 2 changed files with 22 additions and 10 deletions.
4 changes: 1 addition & 3 deletions train_img.py
@@ -103,8 +103,6 @@ def main(args):
         dtype = torch.bfloat16
     elif args.mixed_precision == "fp16":
         dtype = torch.float16
-    elif args.mixed_precision == "fp32":
-        dtype = torch.float32
     else:
         raise ValueError(f"Unknown mixed precision {args.mixed_precision}")
     model: DiT = (
@@ -283,7 +281,7 @@ def main(args):
parser.add_argument("--log-every", type=int, default=10)
parser.add_argument("--ckpt-every", type=int, default=1000)

parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"])
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
parser.add_argument("--lr", type=float, default=1e-4, help="Gradient clipping value")
parser.add_argument("--grad_checkpoint", action="store_true", help="Use gradient checkpointing")
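The net effect of the train_img.py change is that "fp32" is no longer an accepted value for --mixed_precision; argparse rejects it before main() runs. A minimal standalone sketch of that behavior (not the actual script, just the one relevant argument):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])

print(parser.parse_args([]).mixed_precision)                              # "bf16" (default)
print(parser.parse_args(["--mixed_precision", "fp16"]).mixed_precision)   # "fp16"
# parser.parse_args(["--mixed_precision", "fp32"]) now exits with:
#   error: argument --mixed_precision: invalid choice: 'fp32' (choose from 'bf16', 'fp16')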
28 changes: 21 additions & 7 deletions train_video.py
@@ -6,6 +6,7 @@

 import colossalai
 import torch
+import torch.distributed as dist
 from colossalai.booster import Booster
 from colossalai.booster.plugin import LowLevelZeroPlugin
 from colossalai.cluster import DistCoordinator
@@ -49,6 +50,7 @@ def main(args):
     model_string_name = args.model.replace("/", "-")
     # Create an experiment folder
     experiment_dir = f"{args.outputs}/{experiment_index:03d}-{model_string_name}"
+    dist.barrier()
     if coordinator.is_master():
         os.makedirs(experiment_dir, exist_ok=True)
         with open(f"{experiment_dir}/config.txt", "w") as f:
@@ -97,7 +99,12 @@ def main(args):

     # Create model
     img_size = dataset[0][0].shape[-1]
-    dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
+    if args.mixed_precision == "bf16":
+        dtype = torch.bfloat16
+    elif args.mixed_precision == "fp16":
+        dtype = torch.float16
+    else:
+        raise ValueError(f"Unknown mixed precision {args.mixed_precision}")
     model: DiT = (
         DiT_models[args.model](
             input_size=img_size,
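With this hunk, train_video.py resolves the mixed-precision flag the same way as train_img.py and fails loudly on anything else. A minimal sketch of the shared branch (resolve_dtype is an illustrative helper name; in both scripts the branch is written inline inside main()):

import torch


def resolve_dtype(mixed_precision: str) -> torch.dtype:
    # Mirrors the branch train_img.py and train_video.py now share.
    if mixed_precision == "bf16":
        return torch.bfloat16
    elif mixed_precision == "fp16":
        return torch.float16
    else:
        raise ValueError(f"Unknown mixed precision {mixed_precision}")


assert resolve_dtype("bf16") == torch.bfloat16
assert resolve_dtype("fp16") == torch.float16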
@@ -196,11 +203,15 @@ def main(args):

             # Log loss values:
             all_reduce_mean(loss)
-            if coordinator.is_master() and (step + 1) % args.log_every == 0:
-                pbar.set_postfix({"loss": loss.item()})
-                writer.add_scalar("loss", loss.item(), epoch * num_steps_per_epoch + step)
+            global_step = epoch * num_steps_per_epoch + step
+            pbar.set_postfix({"loss": loss.item(), "step": step, "global_step": global_step})
+
+            # Log to tensorboard
+            if coordinator.is_master() and (global_step + 1) % args.log_every == 0:
+                writer.add_scalar("loss", loss.item(), global_step)
 
-            if args.ckpt_every > 0 and (step + 1) % args.ckpt_every == 0:
+            # Save checkpoint
+            if args.ckpt_every > 0 and (global_step + 1) % args.ckpt_every == 0:
                 logger.info(f"Saving checkpoint")
                 save(
                     booster,
@@ -210,12 +221,15 @@ def main(args):
                     lr_scheduler,
                     epoch,
                     step + 1,
+                    global_step + 1,
                     args.batch_size,
                     coordinator,
                     experiment_dir,
                     ema_shape_dict,
                 )
-                logger.info(f"Saved checkpoint at epoch {epoch} step {step + 1} to {experiment_dir}")
+                logger.info(
+                    f"Saved checkpoint at epoch {epoch} step {step + 1} global_step {global_step + 1} to {experiment_dir}"
+                )
 
         # the continue epochs are not resumed, so we need to reset the sampler start index and start step
         dataloader.sampler.set_start_index(0)
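The two hunks above replace the per-epoch step with a running global_step for both tensorboard logging and checkpoint intervals, so the counters no longer reset at each epoch boundary. A small worked example of the arithmetic, with made-up sizes:

# Hypothetical sizes, chosen only to illustrate the arithmetic.
num_steps_per_epoch = 500
log_every = 10
ckpt_every = 1000

epoch, step = 1, 499
global_step = epoch * num_steps_per_epoch + step      # 1 * 500 + 499 = 999

should_log = (global_step + 1) % log_every == 0       # 1000 % 10 == 0   -> True
should_ckpt = (global_step + 1) % ckpt_every == 0     # 1000 % 1000 == 0 -> True

# With the old per-epoch check, (step + 1) % ckpt_every would never hit 0
# when an epoch has only 500 steps; keyed off global_step it fires here.
print(global_step, should_log, should_ckpt)           # 999 True True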
@@ -242,7 +256,7 @@ def main(args):
parser.add_argument("--batch-size", type=int, default=2)
parser.add_argument("--global-seed", type=int, default=42)
parser.add_argument("--num-workers", type=int, default=4)
parser.add_argument("--log-every", type=int, default=50)
parser.add_argument("--log-every", type=int, default=10)
parser.add_argument("--ckpt-every", type=int, default=1000)
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16"])
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
