Skip to content

Commit

Permalink
update MAML code
Browse files Browse the repository at this point in the history
  • Loading branch information
RoyalSkye committed Sep 19, 2022
1 parent cddcdb5 commit 5b470c6
Show file tree
Hide file tree
Showing 5 changed files with 267 additions and 169 deletions.
238 changes: 133 additions & 105 deletions POMO/TSP/TSPModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def __init__(self, **model_params):
self.encoded_nodes = None
# shape: (batch, problem, EMBEDDING_DIM)

def pre_forward(self, reset_state):
self.encoded_nodes = self.encoder(reset_state.problems)
def pre_forward(self, reset_state, weights=None):
self.encoded_nodes = self.encoder(reset_state.problems, weights=weights)
# shape: (batch, problem, EMBEDDING_DIM)
self.decoder.set_kv(self.encoded_nodes)
self.decoder.set_kv(self.encoded_nodes, weights=weights)

def forward(self, state):
def forward(self, state, weights=None):
batch_size = state.BATCH_IDX.size(0)
pomo_size = state.BATCH_IDX.size(1)

Expand All @@ -30,12 +30,12 @@ def forward(self, state):

encoded_first_node = _get_encoding(self.encoded_nodes, selected)
# shape: (batch, pomo, embedding)
self.decoder.set_q1(encoded_first_node) # pre-compute fixed part of the context embedding
self.decoder.set_q1(encoded_first_node, weights=weights) # pre-compute fixed part of the context embedding

else:
encoded_last_node = _get_encoding(self.encoded_nodes, state.current_node)
# shape: (batch, pomo, embedding)
probs = self.decoder(encoded_last_node, ninf_mask=state.ninf_mask)
probs = self.decoder(encoded_last_node, ninf_mask=state.ninf_mask, weights=weights)
# shape: (batch, pomo, problem)

while True:
Expand Down Expand Up @@ -86,15 +86,19 @@ def __init__(self, **model_params):
self.embedding = nn.Linear(2, embedding_dim)
self.layers = nn.ModuleList([EncoderLayer(**model_params) for _ in range(encoder_layer_num)])

def forward(self, data):
# data.shape: (batch, problem, 2)

embedded_input = self.embedding(data)
# shape: (batch, problem, embedding)

out = embedded_input
for layer in self.layers:
out = layer(out)
def forward(self, data, weights=None):
if weights is None:
# data.shape: (batch, problem, 2)
embedded_input = self.embedding(data)
# shape: (batch, problem, embedding)
out = embedded_input
for layer in self.layers:
out = layer(out)
else:
embedded_input = F.linear(data, weights['encoder.embedding.weight'], weights['encoder.embedding.bias'])
out = embedded_input
for idx, layer in enumerate(self.layers):
out = layer(out, weights=weights, index=idx)

return out

Expand All @@ -116,24 +120,32 @@ def __init__(self, **model_params):
self.feedForward = Feed_Forward_Module(**model_params)
self.addAndNormalization2 = Add_And_Normalization_Module(**model_params)

def forward(self, input1):
# input.shape: (batch, problem, EMBEDDING_DIM)
head_num = self.model_params['head_num']

q = reshape_by_heads(self.Wq(input1), head_num=head_num)
k = reshape_by_heads(self.Wk(input1), head_num=head_num)
v = reshape_by_heads(self.Wv(input1), head_num=head_num)
# q shape: (batch, HEAD_NUM, problem, KEY_DIM)

out_concat = multi_head_attention(q, k, v)
# shape: (batch, problem, HEAD_NUM*KEY_DIM)

multi_head_out = self.multi_head_combine(out_concat)
# shape: (batch, problem, EMBEDDING_DIM)

out1 = self.addAndNormalization1(input1, multi_head_out)
out2 = self.feedForward(out1)
out3 = self.addAndNormalization2(out1, out2)
def forward(self, input1, weights=None, index=0):
if weights is None:
# input.shape: (batch, problem, EMBEDDING_DIM)
head_num = self.model_params['head_num']
q = reshape_by_heads(self.Wq(input1), head_num=head_num)
k = reshape_by_heads(self.Wk(input1), head_num=head_num)
v = reshape_by_heads(self.Wv(input1), head_num=head_num)
# q shape: (batch, HEAD_NUM, problem, KEY_DIM)
out_concat = multi_head_attention(q, k, v)
# shape: (batch, problem, HEAD_NUM*KEY_DIM)
multi_head_out = self.multi_head_combine(out_concat)
# shape: (batch, problem, EMBEDDING_DIM)
out1 = self.addAndNormalization1(input1, multi_head_out)
out2 = self.feedForward(out1)
out3 = self.addAndNormalization2(out1, out2)
else:
head_num = self.model_params['head_num']
q = reshape_by_heads(F.linear(input1, weights['encoder.layers.{}.Wq.weight'.format(index)], bias=None), head_num=head_num)
k = reshape_by_heads(F.linear(input1, weights['encoder.layers.{}.Wk.weight'.format(index)], bias=None), head_num=head_num)
v = reshape_by_heads(F.linear(input1, weights['encoder.layers.{}.Wv.weight'.format(index)], bias=None), head_num=head_num)
out_concat = multi_head_attention(q, k, v)
multi_head_out = F.linear(out_concat, weights['encoder.layers.{}.multi_head_combine.weight'.format(index)], weights['encoder.layers.{}.multi_head_combine.bias'.format(index)])
out1 = self.addAndNormalization1(input1, multi_head_out, weights={'weight': weights['encoder.layers.{}.addAndNormalization1.norm.weight'.format(index)], 'bias': weights['encoder.layers.{}.addAndNormalization1.norm.bias'.format(index)]})
out2 = self.feedForward(out1, weights={'weight1': weights['encoder.layers.{}.feedForward.W1.weight'.format(index)], 'bias1': weights['encoder.layers.{}.feedForward.W1.bias'.format(index)],
'weight2': weights['encoder.layers.{}.feedForward.W2.weight'.format(index)], 'bias2': weights['encoder.layers.{}.feedForward.W2.bias'.format(index)]})
out3 = self.addAndNormalization2(out1, out2, weights={'weight': weights['encoder.layers.{}.addAndNormalization2.norm.weight'.format(index)], 'bias': weights['encoder.layers.{}.addAndNormalization2.norm.bias'.format(index)]})

return out3
# shape: (batch, problem, EMBEDDING_DIM)
Expand Down Expand Up @@ -163,60 +175,71 @@ def __init__(self, **model_params):
self.single_head_key = None # saved, for single-head attention
self.q_first = None # saved q1, for multi-head attention

def set_kv(self, encoded_nodes):
# encoded_nodes.shape: (batch, problem, embedding)
head_num = self.model_params['head_num']

self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
# shape: (batch, head_num, pomo, qkv_dim)
self.single_head_key = encoded_nodes.transpose(1, 2)
# shape: (batch, embedding, problem)

def set_q1(self, encoded_q1):
# encoded_q.shape: (batch, n, embedding) # n can be 1 or pomo
head_num = self.model_params['head_num']

self.q_first = reshape_by_heads(self.Wq_first(encoded_q1), head_num=head_num)
# shape: (batch, head_num, n, qkv_dim)

def forward(self, encoded_last_node, ninf_mask):
# encoded_last_node.shape: (batch, pomo, embedding)
# ninf_mask.shape: (batch, pomo, problem)

head_num = self.model_params['head_num']

# Multi-Head Attention
#######################################################
q_last = reshape_by_heads(self.Wq_last(encoded_last_node), head_num=head_num)
# shape: (batch, head_num, pomo, qkv_dim)

q = self.q_first + q_last
# shape: (batch, head_num, pomo, qkv_dim)

out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
# shape: (batch, pomo, head_num*qkv_dim)

mh_atten_out = self.multi_head_combine(out_concat)
# shape: (batch, pomo, embedding)

# Single-Head Attention, for probability calculation
#######################################################
score = torch.matmul(mh_atten_out, self.single_head_key)
# shape: (batch, pomo, problem)

sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
logit_clipping = self.model_params['logit_clipping']

score_scaled = score / sqrt_embedding_dim
# shape: (batch, pomo, problem)

score_clipped = logit_clipping * torch.tanh(score_scaled)

score_masked = score_clipped + ninf_mask

probs = F.softmax(score_masked, dim=2)
# shape: (batch, pomo, problem)
def set_kv(self, encoded_nodes, weights=None):
if weights is None:
# encoded_nodes.shape: (batch, problem, embedding)
head_num = self.model_params['head_num']
self.k = reshape_by_heads(self.Wk(encoded_nodes), head_num=head_num)
self.v = reshape_by_heads(self.Wv(encoded_nodes), head_num=head_num)
# shape: (batch, head_num, pomo, qkv_dim)
self.single_head_key = encoded_nodes.transpose(1, 2)
# shape: (batch, embedding, problem)
else:
head_num = self.model_params['head_num']
self.k = reshape_by_heads(F.linear(encoded_nodes, weights['decoder.Wk.weight'], bias=None), head_num=head_num)
self.v = reshape_by_heads(F.linear(encoded_nodes, weights['decoder.Wv.weight'], bias=None), head_num=head_num)
self.single_head_key = encoded_nodes.transpose(1, 2)

def set_q1(self, encoded_q1, weights=None):
if weights is None:
# encoded_q.shape: (batch, n, embedding) # n can be 1 or pomo
head_num = self.model_params['head_num']
self.q_first = reshape_by_heads(self.Wq_first(encoded_q1), head_num=head_num)
# shape: (batch, head_num, n, qkv_dim)
else:
head_num = self.model_params['head_num']
self.q_first = reshape_by_heads(F.linear(encoded_q1, weights['decoder.Wq_first.weight'], bias=None), head_num=head_num)

def forward(self, encoded_last_node, ninf_mask, weights=None):
if weights is None:
# encoded_last_node.shape: (batch, pomo, embedding)
# ninf_mask.shape: (batch, pomo, problem)
head_num = self.model_params['head_num']
# Multi-Head Attention
#######################################################
q_last = reshape_by_heads(self.Wq_last(encoded_last_node), head_num=head_num)
# shape: (batch, head_num, pomo, qkv_dim)
q = self.q_first + q_last
# shape: (batch, head_num, pomo, qkv_dim)
out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
# shape: (batch, pomo, head_num*qkv_dim)
mh_atten_out = self.multi_head_combine(out_concat)
# shape: (batch, pomo, embedding)
# Single-Head Attention, for probability calculation
#######################################################
score = torch.matmul(mh_atten_out, self.single_head_key)
# shape: (batch, pomo, problem)
sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
logit_clipping = self.model_params['logit_clipping']
score_scaled = score / sqrt_embedding_dim
# shape: (batch, pomo, problem)
score_clipped = logit_clipping * torch.tanh(score_scaled)
score_masked = score_clipped + ninf_mask
probs = F.softmax(score_masked, dim=2)
# shape: (batch, pomo, problem)
else:
head_num = self.model_params['head_num']
q_last = reshape_by_heads(F.linear(encoded_last_node, weights['decoder.Wq_last.weight'], bias=None), head_num=head_num)
q = self.q_first + q_last
out_concat = multi_head_attention(q, self.k, self.v, rank3_ninf_mask=ninf_mask)
mh_atten_out = F.linear(out_concat, weights['decoder.multi_head_combine.weight'], weights['decoder.multi_head_combine.bias'])
score = torch.matmul(mh_atten_out, self.single_head_key)
sqrt_embedding_dim = self.model_params['sqrt_embedding_dim']
logit_clipping = self.model_params['logit_clipping']
score_scaled = score / sqrt_embedding_dim
score_clipped = logit_clipping * torch.tanh(score_scaled)
score_masked = score_clipped + ninf_mask
probs = F.softmax(score_masked, dim=2)

return probs

Expand Down Expand Up @@ -283,20 +306,22 @@ def __init__(self, **model_params):
embedding_dim = model_params['embedding_dim']
self.norm = nn.InstanceNorm1d(embedding_dim, affine=True, track_running_stats=False)

def forward(self, input1, input2):
# input.shape: (batch, problem, embedding)

added = input1 + input2
# shape: (batch, problem, embedding)

transposed = added.transpose(1, 2)
# shape: (batch, embedding, problem)

normalized = self.norm(transposed)
# shape: (batch, embedding, problem)

back_trans = normalized.transpose(1, 2)
# shape: (batch, problem, embedding)
def forward(self, input1, input2, weights=None):
if weights is None:
# input.shape: (batch, problem, embedding)
added = input1 + input2
# shape: (batch, problem, embedding)
transposed = added.transpose(1, 2)
# shape: (batch, embedding, problem)
normalized = self.norm(transposed)
# shape: (batch, embedding, problem)
back_trans = normalized.transpose(1, 2)
# shape: (batch, problem, embedding)
else:
added = input1 + input2
transposed = added.transpose(1, 2)
normalized = F.instance_norm(transposed, weight=weights['weight'], bias=weights['bias'])
back_trans = normalized.transpose(1, 2)

return back_trans

Expand All @@ -310,7 +335,10 @@ def __init__(self, **model_params):
self.W1 = nn.Linear(embedding_dim, ff_hidden_dim)
self.W2 = nn.Linear(ff_hidden_dim, embedding_dim)

def forward(self, input1):
# input.shape: (batch, problem, embedding)

return self.W2(F.relu(self.W1(input1)))
def forward(self, input1, weights=None):
if weights is None:
# input.shape: (batch, problem, embedding)
return self.W2(F.relu(self.W1(input1)))
else:
output = F.relu(F.linear(input1, weights['weight1'], bias=weights['bias1']))
return F.linear(output, weights['weight2'], bias=weights['bias2'])
Loading

0 comments on commit 5b470c6

Please sign in to comment.