-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
97 lines (83 loc) · 4.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import torch
from models.layers.mesh import Mesh, PartMesh
from models.networks import init_net, sample_surface, local_nonuniform_penalty
import utils
import numpy as np
from models.losses import chamfer_distance, BeamGapLoss
from options import Options
import time
import os
# ---------------------------------------------------------------------------
# Reconstruction driver.
#
# Deforms the initial mesh (opts.initial_mesh) to fit the input point cloud
# with normals (opts.input_pc). The mesh is split into parts by PartMesh
# (bfs_depth=opts.overlap); the network built by init_net yields new vertex
# positions per part, which are written back into the shared main mesh and
# optimized against a Chamfer loss on positions + normals, or against the
# beam-gap loss on selected iterations while i < opts.beamgap_iterations.
# Every opts.upsamp iterations the current mesh is passed through
# utils.manifold_upsample and the network/optimizer/scheduler are re-created
# for the new topology. Snapshots and the final mesh go to opts.save_path.
# ---------------------------------------------------------------------------
options = Options()
opts = options.args
torch.manual_seed(opts.torch_seed)
device = torch.device(f'cuda:{opts.gpu}' if torch.cuda.is_available() else 'cpu')
print('device: {}'.format(device))

# initial mesh
mesh = Mesh(opts.initial_mesh, device=device, hold_history=True)

# Input point cloud, moved into the initial mesh's frame using the mesh's own
# scale and translation so that points and mesh share coordinates.
input_xyz, input_normals = utils.read_pts(opts.input_pc)
input_xyz /= mesh.scale
input_xyz += mesh.translations[None, :]
# [None, :, :] prepends a leading batch dimension of size 1.
input_xyz = torch.Tensor(input_xyz).type(options.dtype()).to(device)[None, :, :]
input_normals = torch.Tensor(input_normals).type(options.dtype()).to(device)[None, :, :]
# Positions and normals concatenated on the last axis: the target handed to
# the beam-gap loss, both initially and after every up-sampling step.
input_pc_with_normals = torch.cat([input_xyz, input_normals], dim=-1)

part_mesh = PartMesh(mesh, num_parts=options.get_num_parts(len(mesh.faces)), bfs_depth=opts.overlap)
print(f'number of parts {part_mesh.n_submeshes}')
net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)
beamgap_loss = BeamGapLoss(device)
if opts.beamgap_iterations > 0:
    print('beamgap on')
    beamgap_loss.update_pm(part_mesh, input_pc_with_normals)

for i in range(opts.iterations):
    # Depends only on the position inside the current up-sampling cycle, so it
    # is the same for every part of this iteration.
    num_samples = options.get_num_samples(i % opts.upsamp)
    if opts.global_step:
        # Single optimizer step per iteration: gradients from all parts
        # accumulate before stepping (see after the part loop).
        optimizer.zero_grad()
    start_time = time.time()
    for part_i, est_verts in enumerate(net(rand_verts, part_mesh)):
        if not opts.global_step:
            optimizer.zero_grad()
        # est_verts[0]: first (presumably only) batch entry — TODO confirm
        part_mesh.update_verts(est_verts[0], part_i)
        recon_xyz, recon_normals = sample_surface(part_mesh.main_mesh.faces,
                                                  part_mesh.main_mesh.vs.unsqueeze(0),
                                                  num_samples)
        # calc chamfer loss w/ normals
        recon_xyz, recon_normals = recon_xyz.type(options.dtype()), recon_normals.type(options.dtype())
        xyz_chamfer_loss, normals_chamfer_loss = chamfer_distance(
            recon_xyz, input_xyz, x_normals=recon_normals, y_normals=input_normals,
            unoriented=opts.unoriented)
        if (i < opts.beamgap_iterations) and (i % opts.beamgap_modulo == 0):
            loss = beamgap_loss(part_mesh, part_i)
        else:
            loss = xyz_chamfer_loss + opts.ang_wt * normals_chamfer_loss
        if opts.local_non_uniform > 0:
            loss += opts.local_non_uniform * local_nonuniform_penalty(part_mesh.main_mesh).float()
        loss.backward()
        if not opts.global_step:
            optimizer.step()
            scheduler.step()
        # Detach the shared vertex tensor in place so the next part's forward
        # pass starts a fresh autograd graph instead of chaining onto this one.
        part_mesh.main_mesh.vs.detach_()
    if opts.global_step:
        optimizer.step()
        scheduler.step()
    end_time = time.time()
    # NOTE: `loss` here is the loss of the last part processed, not a sum
    # over all parts.
    print(f'{os.path.basename(opts.input_pc)}; iter: {i} out of: {opts.iterations}; loss: {loss.item():.4f};'
          f' sample count: {num_samples}; time: {end_time - start_time:.2f}')
    if i % opts.export_interval == 0 and i > 0:
        print('exporting reconstruction... current LR: {}'.format(optimizer.param_groups[0]['lr']))
        with torch.no_grad():
            part_mesh.export(os.path.join(opts.save_path, f'recon_iter_{i}.obj'))
    if i > 0 and (i + 1) % opts.upsamp == 0:
        mesh = part_mesh.main_mesh
        # Target 1.5x the current face count, bounded by opts.max_faces.
        num_faces = int(np.clip(len(mesh.faces) * 1.5, len(mesh.faces), opts.max_faces))
        if num_faces > len(mesh.faces) or opts.manifold_always:
            # up-sample mesh
            mesh = utils.manifold_upsample(mesh, opts.save_path, Mesh,
                                           num_faces=min(num_faces, opts.max_faces),
                                           res=opts.manifold_res, simplify=True)
            part_mesh = PartMesh(mesh, num_parts=options.get_num_parts(len(mesh.faces)), bfs_depth=opts.overlap)
            print(f'upsampled to {len(mesh.faces)} faces; number of parts {part_mesh.n_submeshes}')
            # Topology changed: network, optimizer, inputs and scheduler are
            # rebuilt from scratch for the new mesh.
            net, optimizer, rand_verts, scheduler = init_net(mesh, part_mesh, device, opts)
            if i < opts.beamgap_iterations:
                print('beamgap updated')
                # Same target as the initial update_pm call: positions + normals.
                beamgap_loss.update_pm(part_mesh, input_pc_with_normals)

with torch.no_grad():
    # Export the mesh the loop has been optimizing (the part mesh's main mesh).
    part_mesh.main_mesh.export(os.path.join(opts.save_path, 'last_recon.obj'))