-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathmain.py
93 lines (80 loc) · 6.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import yaml
import torch
import random
import argparse
import numpy as np
from loop import loop
def _build_parser():
    """Create the command-line parser for this script.

    Every option except ``--config`` uses ``default=argparse.SUPPRESS``, so an
    option that is not given on the command line is absent from the parsed
    namespace and therefore does not override the value from the config file.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', help='Path to config file', type=str, default='./example_config.yml')
    parser.add_argument('--output_path', help='Output directory (will be created)', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--gpu', help='GPU index', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--seed', help='Random seed', type=int, default=argparse.SUPPRESS)

    # CLIP-related
    parser.add_argument('--text_prompt', help='Target text prompt', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--base_text_prompt', help='Base text prompt describing input mesh', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--clip_model', help='CLIP Model for text comparison', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_clip_model', help='CLIP Model for consistency', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_vit_stride', help='New stride for ViT patch interpolation', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_vit_layer', help='Which layer to take ViT patch features from (0-11)', type=int, default=argparse.SUPPRESS)

    # Mesh
    parser.add_argument('--mesh', help='Path to input mesh', type=str, default=argparse.SUPPRESS)
    parser.add_argument('--retriangulate', help='Use isotropic remeshing', type=int, default=argparse.SUPPRESS, choices=[0, 1])

    # Render settings
    parser.add_argument('--bsdf', help='Render technique', type=str, default=argparse.SUPPRESS, choices=['diffuse', 'pbr'])

    # Hyper-parameters
    parser.add_argument('--lr', help='Learning rate', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--epochs', help='Number of optimization steps', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--clip_weight', help='Weight for CLIP loss', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--delta_clip_weight', help='Wight for delta-CLIP loss', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--regularize_jacobians_weight', help='Weight for jacobian regularization', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_loss_weight', help='Weight for viewpoint consistency penalty', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_elev_filter', help='Elev. angle threshold for filtering out pairs of viewpoints for consistency loss', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--consistency_azim_filter', help='Azim. angle threshold for filtering out pairs of viewpoints for consistency loss', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--batch_size', help='Number of images rendered at the same time', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--train_res', help='Resolution of render before downscaling to CLIP size', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--resize_method', help='Image downsampling/upsampling method', type=str, default=argparse.SUPPRESS, choices=['cubic', 'linear', 'lanczos2', 'lanczos3'])

    ## Camera Parameters ##
    parser.add_argument('--fov_min', help='Minimum camera field of view angle during renders', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--fov_max', help='Maximum camera field of view angle during renders', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--dist_min', help='Minimum distance of camera from mesh during renders', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--dist_max', help='Maximum distance of camera from mesh during renders', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--light_power', help='Light intensity', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--elev_alpha', help='Alpha parameter for Beta distribution for elevation sampling', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--elev_beta', help='Beta parameter for Beta distribution for elevation sampling', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--elev_max', help='Maximum elevation angle in degree', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--azim_min', help='Minimum azimuth angle in degree', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--azim_max', help='Maximum azimuth angle in degree', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--aug_loc', help='Offset mesh from center of image?', type=int, default=argparse.SUPPRESS, choices=[0, 1])
    parser.add_argument('--aug_light', help='Augment the direction of light around the camera', type=int, default=argparse.SUPPRESS, choices=[0, 1])
    parser.add_argument('--aug_bkg', help='Augment the background', type=int, default=argparse.SUPPRESS, choices=[0, 1])
    parser.add_argument('--adapt_dist', help='Adjust camera distance to account for scale of shape', type=int, default=argparse.SUPPRESS, choices=[0, 1])

    # Logging
    parser.add_argument('--log_interval', help='Interval for logging, every X epochs', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--log_interval_im', help='Interval for logging renders image, every X epochs', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--log_elev', help='Logging elevation angle', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--log_fov', help='Logging field of view', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--log_dist', help='Logging distance from object', type=float, default=argparse.SUPPRESS)
    parser.add_argument('--log_res', help='Logging render resolution', type=int, default=argparse.SUPPRESS)
    parser.add_argument('--log_light_power', help='Light intensity for logging', type=float, default=argparse.SUPPRESS)
    return parser


def _load_config(path):
    """Read the YAML file at *path* and return its top-level mapping.

    Returns an empty dict for an empty file (``yaml.safe_load`` yields ``None``
    in that case).

    Raises:
        SystemExit: if the file is not valid YAML or its top level is not a
            mapping. Previously a parse error was only printed and the script
            then crashed with an unrelated ``NameError`` on ``cfg``.
        OSError: if the file cannot be opened (propagated from ``open``).
    """
    with open(path, 'r', encoding='utf-8') as f:
        try:
            loaded = yaml.safe_load(f)
        except yaml.YAMLError as e:
            raise SystemExit(f'Could not parse config file {path!r}: {e}') from e

    if loaded is None:
        return {}
    if not isinstance(loaded, dict):
        raise SystemExit(
            f'Config file {path!r} must contain a mapping at the top level, '
            f'got {type(loaded).__name__}'
        )
    return loaded


def _seed_everything(seed):
    """Seed the Python, NumPy and PyTorch random number generators with *seed*."""
    random.seed(seed)
    # NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here
    # does not change string hashing in the current process; it is only
    # inherited by child processes started afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


def main():
    """Script entry point.

    Parses the command line, loads the YAML config named by ``--config``,
    overrides config entries with every option explicitly given on the command
    line, prints the merged configuration, seeds the random number generators
    from ``cfg['seed']`` and finally calls ``loop(cfg)``.

    Raises:
        SystemExit: on invalid command-line arguments or an unusable config
            file.
        KeyError: if no ``seed`` is available from either the config file or
            ``--seed``.
    """
    args = _build_parser().parse_args()

    # Start from an empty mapping so `cfg` is always bound, even if no config
    # path is available.
    cfg = {}
    if args.config is not None:
        cfg = _load_config(args.config)

    # Command-line values win over the file. Only options actually passed are
    # present in `args` (plus `config`, which always has a value).
    cfg.update(vars(args))

    print(yaml.dump(cfg, default_flow_style=False))

    if 'seed' not in cfg:
        raise KeyError("'seed' must be set in the config file or via --seed")
    _seed_everything(cfg['seed'])

    loop(cfg)
    print('Done')
# Run the script only when this file is executed directly; importing it as a
# module must not start anything.
if __name__ == "__main__":
    main()