-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathneural_highlighter.py
30 lines (27 loc) · 1.09 KB
/
neural_highlighter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch.nn as nn
from utils import FourierFeatureTransform
class NeuralHighlighter(nn.Module):
    """MLP that maps input points to a probability distribution over classes.

    The network is an input projection followed by ``depth`` hidden blocks of
    (Linear -> ReLU -> LayerNorm), then a Linear head with a Softmax so each
    output row sums to 1.

    Args:
        depth: number of hidden (Linear -> ReLU -> LayerNorm) blocks after
            the input layer.
        width: hidden-layer width.
        out_dim: number of output classes.
        input_dim: dimensionality of each input point (default 3).
        positional_encoding: if True, pass the input through a Fourier
            feature transform before the MLP.
        sigma: scale of the Fourier features; only used when
            ``positional_encoding`` is True.
    """

    def __init__(self, depth, width, out_dim, input_dim=3,
                 positional_encoding=False, sigma=5.0):
        super().__init__()  # py3 zero-arg super instead of legacy two-arg form
        layers = []
        if positional_encoding:
            # NOTE(review): the fan-in width * 2 + input_dim below implies the
            # Fourier transform emits 2*width features concatenated with the
            # raw input — confirm against utils.FourierFeatureTransform.
            layers.append(FourierFeatureTransform(input_dim, width, sigma))
            layers.append(nn.Linear(width * 2 + input_dim, width))
        else:
            layers.append(nn.Linear(input_dim, width))
        # Shared activation/norm for the input block (was duplicated in both
        # branches above; hoisted — layer order and state_dict keys unchanged).
        layers.append(nn.ReLU())
        layers.append(nn.LayerNorm([width]))
        # Hidden stack: `depth` identical blocks.
        for _ in range(depth):
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm([width]))
        # Output head: probabilities along dim 1, so the input is expected to
        # be 2D with shape (num_points, input_dim).
        layers.append(nn.Linear(width, out_dim))
        layers.append(nn.Softmax(dim=1))
        # Debug `print(self.mlp)` removed: a library constructor should not
        # write to stdout on every instantiation.
        self.mlp = nn.ModuleList(layers)

    def forward(self, x):
        """Apply every layer in order; returns (num_points, out_dim) probabilities."""
        for layer in self.mlp:
            x = layer(x)
        return x