-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path: Focusing.py
70 lines (47 loc) · 2.21 KB
/
Focusing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import tensorflow as tf
import numpy as np
def ContentFocusing(k_t, M_t, b_t, K = None):
    '''
    Content-based addressing: softmax over the similarity of a key vector
    to each memory row (Neural Turing Machine style).

    k_t : (M,), Key Vector generated by EITHER HEAD (in whichever HEAD this function is used in for addressing)
    M_t : (N,M), Memory Matrix at time t.
    b_t : Scalar, Key Strength hyperparameter (beta >= 0); larger values
          sharpen the resulting distribution.
    K : Function, Similarity Measure K(row, key) -> scalar; if None,
        Cosine Similarity will be used.
    RETURNS:
    w_ct : (N,), Weighting after Content Focusing (non-negative, sums to 1).
    '''
    M_t = np.asarray(M_t, dtype=float)
    k_t = np.asarray(k_t, dtype=float).reshape(-1)
    N, M = M_t.shape
    assert k_t.shape == (M,)
    if K is None:  # 'is None', not '== None' (PEP 8); avoids __eq__ surprises
        def K(u, v):
            u = np.asarray(u).ravel()
            v = np.asarray(v).ravel()
            return np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    # Similarity of the key against every memory row.
    sims = np.array([float(K(M_t[i], k_t)) for i in range(N)])
    # Numerically-stable softmax with key-strength temperature b_t:
    # subtracting the max before exp avoids overflow without changing the result.
    scaled = b_t * sims
    exp_of_AKV = np.exp(scaled - np.max(scaled))  # AKV for Applied K Vector
    w_ct = exp_of_AKV / np.sum(exp_of_AKV)
    assert w_ct.shape == (N,)
    return w_ct
def LocationFocusing( k_t, M_t, b_t, g_t, w_prev, s_t, gamma_t, K = None):
    '''
    Full NTM addressing pipeline: content focusing, interpolation with the
    previous weighting, circular convolutional shift, then sharpening.

    k_t, M_t, b_t, K : SAME AS IN CONTENT FOCUSING
    g_t : Scalar, Interpolation Gate in the range (0,1) emitted by HEAD IN USE.
    w_prev : (N,), Weight Vector produced by the HEAD IN USE at the previous time step.
    s_t : (len(shift_range),), The weights emitted by the HEAD IN USE that defines the normalized distribution over the allowed integer shifts (which is shift_range object)
    gamma_t : Scalar, Sharpening Factor >= 1
    RETURNS:
    w_t : (N,), Final Weight Vector (non-negative, sums to 1).
    '''
    # 1) Content focusing. np.asarray tolerates either an ndarray or a
    #    framework tensor being returned.
    w_ct = np.asarray(ContentFocusing(k_t, M_t, b_t, K)).reshape(-1)
    N, M = np.asarray(M_t).shape
    w_prev = np.asarray(w_prev, dtype=float).reshape(-1)
    assert w_prev.shape == (N,)
    # 2) Interpolation between content weighting and previous weighting.
    w_gt = g_t * w_ct + (1 - g_t) * w_prev
    # 3) Circular convolutional shift, vectorized:
    #       w_hat[i] = sum_j w_gt[j] * s_t[(i - j) mod N]
    #    Build the (N,N) index matrix once and do a single matvec instead of
    #    the O(N^2) Python double loop.
    # NOTE(review): like the original loop, this indexes s_t at positions up
    # to N-1, so it requires len(s_t) >= N; a short shift distribution must
    # be zero-padded by the caller — confirm against the HEAD implementation.
    s_t = np.asarray(s_t, dtype=float)
    idx = (np.arange(N)[:, None] - np.arange(N)[None, :]) % N
    w_hat_t = s_t[idx] @ w_gt
    # 4) Sharpening: raise to gamma_t (>= 1) and renormalize.
    powered = np.power(w_hat_t, gamma_t)
    w_t = powered / np.sum(powered)
    return w_t