-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
87 lines (70 loc) · 2.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Train a diffusion model on amass.
"""
import os
import json
import argparse
from collections import OrderedDict
from guided_diffusion import dist_util, logger
from guided_diffusion.resample import create_named_schedule_sampler
from guided_diffusion.script_util_transformer import (
create_model_condition_and_diffusion,
)
from guided_diffusion.train_util import TrainLoop
from utils import utils_option as option
from data.image_datasets import load_data_amass
def main():
    """Entry point: set up distributed training and logging, build the
    BoDiffusion model + diffusion process, load AMASS data, and train."""
    opt = create_opts()

    dist_util.setup_dist(devices=opt['gpu_ids'])
    logger.configure(dir=opt['path']['root'])

    # Persist the resolved options next to the logs for reproducibility.
    with open(os.path.join(opt['path']['root'], 'options'), 'w') as f:
        json.dump(opt, f)

    logger.log("creating BoDiffusion...")
    model, diffusion = create_model_condition_and_diffusion(
        use_fp16=opt['fp16']['use_fp16'],
        **opt['ddpm'],
        **opt['diffusion'],
    )

    # Report trainable parameter counts (whole model vs. transformer blocks).
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    block_params = sum(p.numel() for p in model.blocks.parameters() if p.requires_grad)
    logger.log("** the model has " + str(total_params) + ' parameters **')
    logger.log("** the blocks have " + str(block_params) + ' parameters **')

    model.to(dist_util.dev())
    if opt['fp16']['use_fp16']:
        model.convert_to_fp16()

    schedule_sampler = create_named_schedule_sampler(opt['train']['schedule_sampler'], diffusion)

    logger.log("creating data loader...")
    data = load_data_amass(
        opt=opt,
        class_cond=opt['ddpm']['class_cond'],
        joint_cond=True,
        joint_cond_L=True,
    )

    logger.log("training...")
    train_cfg = opt['train']
    loader_cfg = opt['datasets']['train']
    fp16_cfg = opt['fp16']
    TrainLoop(
        model=model,
        diffusion=diffusion,
        data=data,
        batch_size=loader_cfg['dataloader_batch_size'],
        microbatch=loader_cfg['dataloader_microbatch'],
        lr=train_cfg['lr'],
        ema_rate=train_cfg['ema_rate'],
        log_interval=train_cfg['log_interval'],
        save_interval=train_cfg['save_interval'],
        resume_checkpoint=opt['path']['resume_checkpoint'],
        use_fp16=fp16_cfg['use_fp16'],
        fp16_scale_growth=fp16_cfg['fp16_scale_growth'],
        schedule_sampler=schedule_sampler,
        weight_decay=train_cfg['weight_decay'],
        lr_anneal_steps=train_cfg['lr_anneal_steps'],
    ).run_loop()
def create_opts():
    """Parse the ``-opt`` CLI argument and load the training option file.

    The option file is JSON extended with ``//`` line comments: on every
    line, everything from the first ``//`` onward is discarded before the
    remainder is parsed as strict JSON.

    Returns:
        OrderedDict: the parsed options with key order preserved.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, default='options/train.json', help='Path to option JSON file.')
    opt_path = parser.parse_args().opt
    # NOTE(review): splitting on '//' also truncates string values that
    # contain '//' (e.g. "http://..."); assumed absent from option files
    # — confirm against the shipped options/*.json.
    with open(opt_path, 'r') as f:
        # Collect stripped lines and join once (avoids quadratic += growth).
        stripped = [line.split('//')[0] + '\n' for line in f]
    opt = json.loads(''.join(stripped), object_pairs_hook=OrderedDict)
    return opt
# Run training only when executed as a script (not on import).
if __name__ == "__main__":
    main()