-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmoudle.py
22 lines (20 loc) · 1.02 KB
/
moudle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import torch
import torch.nn as nn
class Highway(nn.Module):
    """Stack of highway layers whose carry branch is a learned linear projection.

    Each layer computes::

        g = sigmoid(gate(x))
        y = g * func(nonlinear(x)) + (1 - g) * linear(x)

    The classic highway layer uses the identity as its carry branch. Here the
    carry branch is an ``nn.Linear``, so the feature size may change from
    ``in_features`` to ``out_features``.

    Args:
        in_features: size of the last dimension of the input to the first layer.
        out_features: size of the last dimension produced by every layer.
        func: callable applied to the transform branch (e.g. ``torch.relu``).
        n_layers: number of stacked highway layers.
        bias: whether the ``nn.Linear`` sublayers use a bias term.
    """

    def __init__(self, in_features, out_features, func, n_layers=1, bias=True):
        super().__init__()
        self.n_layers = n_layers
        self.func = func

        # Only the first layer sees `in_features`. Every later layer consumes
        # the previous layer's output, which has `out_features` features.
        # Sizing all layers with `in_features` made n_layers > 1 crash
        # whenever in_features != out_features.
        input_sizes = [in_features if i == 0 else out_features for i in range(n_layers)]

        def _make_layers():
            # One Linear per highway layer, each mapping to `out_features`.
            return nn.ModuleList(
                [
                    nn.Linear(in_features=size, out_features=out_features, bias=bias)
                    for size in input_sizes
                ]
            )

        # Keep the creation order (transform, carry, gate) so the random
        # parameter initialisation sequence matches the original module.
        self.nonlinear = _make_layers()
        self.linear = _make_layers()
        self.gate = _make_layers()

    def forward(self, x):
        """Apply every highway layer in sequence.

        Args:
            x: tensor whose last dimension is ``in_features``. Leading
                dimensions are passed through by ``nn.Linear``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``out_features``. If ``n_layers`` is 0, ``x`` is returned unchanged.
        """
        for transform, carry, gate_layer in zip(self.nonlinear, self.linear, self.gate):
            gate = torch.sigmoid(gate_layer(x))
            # Convex mix of the transformed and the (projected) carried input.
            x = gate * self.func(transform(x)) + (1 - gate) * carry(x)
        return x