-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhelper_networks.py
61 lines (45 loc) · 1.83 KB
/
helper_networks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
import torch.nn as nn
def maxpool(x, dim=-1, keepdim=False):
    '''Return the maximum of ``x`` along ``dim``, discarding the argmax indices.

    Args:
        x: object exposing ``max(dim=..., keepdim=...)`` that yields a
            ``(values, indices)`` pair (e.g. a ``torch.Tensor``)
        dim (int): dimension to reduce over
        keepdim (bool): forwarded to ``x.max``; whether the reduced
            dimension is kept in the result
    '''
    # x.max(dim=...) yields (values, indices); only the values are wanted.
    values_and_indices = x.max(dim=dim, keepdim=keepdim)
    return values_and_indices[0]
class SimplePointnet(nn.Module):
    '''PointNet-style encoder that turns a set of points into one latent code.

    Implementation taken from the Occupancy Networks code:
    https://github.com/autonomousvision/occupancy_networks/blob/master/im2mesh/encoder/pointnet.py

    Args:
        c_dim (int): dimension of latent code c
        dim (int): input points dimension
        hidden_dim (int): hidden dimension of the network
    '''

    def __init__(self, c_dim=128, dim=3, hidden_dim=128):
        super().__init__()
        self.c_dim = c_dim

        # Width of a per-point feature once its pooled copy is concatenated.
        doubled = 2 * hidden_dim

        self.fc_pos = nn.Linear(dim, doubled)
        self.fc_0 = nn.Linear(doubled, hidden_dim)
        self.fc_1 = nn.Linear(doubled, hidden_dim)
        self.fc_2 = nn.Linear(doubled, hidden_dim)
        self.fc_3 = nn.Linear(doubled, hidden_dim)
        self.fc_c = nn.Linear(hidden_dim, c_dim)

        self.actvn = nn.ReLU()
        self.pool = maxpool

    def forward(self, p):
        '''Encode a batch of point sets.

        Args:
            p: tensor of size B x T x D (B batch size, T number of points,
                D features per point)

        Returns:
            Tensor of size B x c_dim holding one latent code per batch item.
        '''
        # The three-way unpack rejects any input that is not 3-dimensional.
        _n_batch, _n_points, _n_feats = p.size()

        # Per-point features, size B x T x F.
        feats = self.fc_0(self.actvn(self.fc_pos(p)))

        # Each stage appends the max over points (broadcast back to every
        # point) to the per-point features, then projects back down to F.
        for stage in (self.fc_1, self.fc_2, self.fc_3):
            summary = self.pool(feats, dim=1, keepdim=True)
            feats = torch.cat([feats, summary.expand(feats.size())], dim=2)
            feats = stage(self.actvn(feats))

        # Collapse the point dimension: B x T x F -> B x F.
        global_feat = self.pool(feats, dim=1)
        return self.fc_c(self.actvn(global_feat))