multitrain.py
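"""Run a batch of DCGAN training jobs described in a CSV config file.

For each job the script builds a training command, runs it as a subprocess,
optionally renders the generated samples to video, and backs up checkpoints.
A shared-state manager lets a background worker render videos periodically
while a job is still training.
"""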
import datetime
import os
import subprocess
import sys
import traceback
from multiprocessing import Pool
from multiprocessing.managers import BaseManager
from os import listdir
from os.path import isfile, join

from dcgan_cmd_builder import *
from files_utils import backup_checkpoint, must_backup_checkpoint
from shared_state import ThreadsSharedState
from video_utils import process_videos_job_param, periodic_render_job


class MyManager(BaseManager):
    pass


if not os.path.exists('renders'):
    os.makedirs('renders')

# register the shared-state type on our manager subclass rather than on
# BaseManager itself
MyManager.register('ThreadsSharedState', ThreadsSharedState)
manager = MyManager()
manager.start()
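
# ThreadsSharedState lives in the manager's server process; this script and
# the periodic render worker both access it through proxies, so settings such
# as the frames threshold stay consistent across processes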

# constants
fps = 60
samples_prefix = 'samples_'

# discover datasets and candidate CSV config files
data_folders = [f for f in listdir('data/')]
csv_files = [f for f in listdir('.') if isfile(join('.', f)) and f.endswith('.csv')]
csv_files.sort()

# worker pool used for background video rendering
pool = Pool(processes=10)

# defaults, overridable from the command line
csv_file = None
gpu_idx = None
enable_cache = True
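
# command-line flags:
#   --config <file.csv>   job description file (default: first *.csv found)
#   --gpu_idx <n>         GPU index passed to build_job_command
#   --disable_cache       disable the dataset cache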
if len(sys.argv) > 1:
    params = sys.argv[1:]
    for idx, param in enumerate(params):
        if param == '--config':
            csv_file = params[idx + 1]
            print('csv file == {}'.format(params[idx + 1]))
        if param == '--gpu_idx':
            gpu_idx = params[idx + 1]
            print('gpu_idx == {}'.format(params[idx + 1]))
        if param == '--disable_cache':
            enable_cache = False

# validate csv config file
if csv_file is None:
    if len(csv_files) == 0:
        print('Error: no csv file')
        exit(1)
    csv_file = csv_files[0]
    print('found config file: ' + csv_file)
else:
    print('config file passed in param: ' + csv_file)
print()

# parse and validate jobs
jobs = Job.from_csv_file(csv_file)
Job.validate(jobs)

# launch the periodic render job if needed
auto_periodic_renders = Job.must_start_auto_periodic_renders(jobs)
shared_state = None
if auto_periodic_renders:
    # noinspection PyUnresolvedReferences
    shared_state = manager.ThreadsSharedState()
    # apply_async: the render job runs in the background; a blocking
    # pool.apply would stall the script before any training job starts
    pool.apply_async(periodic_render_job, args=[shared_state, True])

# run the jobs sequentially
for job in jobs:
    try:
        print('')
        if job.has_auto_periodic_render:
            # TODO: pass the job object directly
            shared_state.init_current_cut()
            shared_state.set_folder(job.sample_folder)
            shared_state.set_job_name(job.name)
            shared_state.set_frames_threshold(job.auto_render_period * fps)
            shared_state.set_upload_to_ftp(job.upload_to_ftp)
            shared_state.set_delete_at_the_end(job.delete_images_after_render)
            print('frames threshold: {}'.format(shared_state.get_frames_threshold()))
            print('sample folder: {}'.format(shared_state.get_sample_folder()))
        print('dataset size: {}'.format(job.dataset_size))
        print('video length: {:0.2f} min.'.format(job.video_length))
        print('frames per minute: {}'.format(fps * 60))
        print('automatic periodic render: {}'.format(auto_periodic_renders))
        print('sample resolution: {}'.format(job.sample_res))
        if job.render_res is not None:
            print('render resolution: {}'.format(job.render_res))
        print('boxes: {}'.format(job.get_boxes()))
        if job.has_auto_periodic_render:
            shared_state.set_sample_res(job.sample_res)
            shared_state.set_render_res(job.render_res)
        print('')

        # subprocess.run blocks until the training process exits
        begin = datetime.datetime.now().replace(microsecond=0)
        job_cmd = job.build_job_command(gpu_idx, enable_cache)
        print('command: ' + ' '.join('{}'.format(v) for v in job_cmd))
        process = subprocess.run(job_cmd)
        print('return code: {}'.format(process.returncode))
        duration = datetime.datetime.now().replace(microsecond=0) - begin
        print('duration of the job {} -> {}'.format(job.name, duration))

        # process the video; the render runs synchronously here, the async
        # variant is kept commented out below
        if process.returncode == 0:
            if job.render_video:
                if auto_periodic_renders:
                    os.rename(job.sample_folder, shared_state.get_time_cut_folder_name())
                    job.sample_folder = shared_state.get_time_cut_folder_name()
                # pool.apply_async(process_video_job_param, job)
                process_videos_job_param(job)
            # backup checkpoint one last time
            if must_backup_checkpoint() and job.use_checkpoints:
                backup_checkpoint(job.name)
    except Exception as e:
        print('error during process of {} -> {}'.format(job.name, e))
        print(traceback.format_exc())

pool.close()
pool.join()
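
# example invocation (assuming a hypothetical jobs.csv in the working directory):
#   python multitrain.py --config jobs.csv --gpu_idx 0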