-
Notifications
You must be signed in to change notification settings - Fork 317
/
Copy pathcnn_load_pretrain.m
65 lines (55 loc) · 1.94 KB
/
cnn_load_pretrain.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
% FILE: cnn_load_pretrain.m
%
% This function takes an input network and fills its parameter weights from
% a pretrained network. Parameters are matched by name. If the target
% network is empty, the structure of the pretrained network is used instead.
% This happens when cnn_init.m fails to initialize the model structure.
%
% INPUT: net (target network)
% prepath (path to a network with pretrained weights)
%
% OUTPUT: net (target network with pretrained weights)
function net = cnn_load_pretrain(net, prepath)
% CNN_LOAD_PRETRAIN  Fill a network's parameters from a pretrained model.
%
%   NET = CNN_LOAD_PRETRAIN(NET, PREPATH) loads the model stored in the
%   MAT-file PREPATH, converts it to a dagnn.DagNN object and copies the
%   value of every parameter whose name also appears in NET. If NET is
%   empty, the pretrained network itself is returned instead.
%
%   After loading, the function additionally:
%     * inserts the dropout layers 'drop6' / 'drop7' (rate 0.5) when they
%       are missing and the surrounding layers/variables are present,
%     * clears net.meta.normalization.averageImage,
%     * reshapes the first two parameters of every dagnn.BatchNorm layer
%       into column vectors.

    % Load the pretrained model; some files wrap it in a 'net' field.
    stored = load(prepath);
    if isfield(stored, 'net')
        stored = stored.net;
    end

    % Convert the pretrained network to DagNN (easy indexing). A struct
    % with a 'params' field is a saved DagNN, otherwise assume SimpleNN.
    if isfield(stored, 'params')
        prenet = dagnn.DagNN.loadobj(stored);
    else
        prenet = dagnn.DagNN.fromSimpleNN(stored);
    end
    clear stored;

    if isempty(net)
        % No target structure available: use the pretrained one as is.
        net = prenet;
    else
        % Copy values of parameters that share the same canonical name.
        for i = 1:numel(net.params)
            idx = prenet.getParamIndex(net.params(i).name);
            if ~isnan(idx)
                net.params(i).value = prenet.params(idx).value;
            end
        end
    end

    % DagNN index lookups report an unknown name as NaN (not []), so an
    % isempty() test alone never detects a missing layer. Accept both
    % conventions to stay compatible across MatConvNet versions.
    notFound = @(idx) isempty(idx) || any(isnan(idx));

    % Each row: dropout layer name, input variable, output variable, and
    % the layer that must consume the dropout output.
    dropSpecs = { ...
        'drop6', 'fc6x', 'fc6xd', 'fc7'; ...
        'drop7', 'fc7x', 'fc7xd', 'score_fr'};
    for k = 1:size(dropSpecs, 1)
        [dropName, inVar, outVar, consumer] = dropSpecs{k, :};
        % Only rewire when the dropout layer is absent AND the expected
        % neighbours exist; otherwise setLayerInputs would fail or the new
        % layer would hang off a dangling input variable.
        if notFound(net.getLayerIndex(dropName)) && ...
                ~notFound(net.getLayerIndex(consumer)) && ...
                ~notFound(net.getVarIndex(inVar))
            net.addLayer(dropName, dagnn.DropOut('rate', 0.5), inVar, outVar);
            net.setLayerInputs(consumer, {outVar});
        end
    end

    % remove average image
    net.meta.normalization.averageImage = [];

    % NOTE Reshape multipliers and biases in BN to be vectors instead of
    % 1x1xK matrices. Otherwise, there will be errors in cnn_train_dag.m
    for i = 1:numel(net.layers)
        if isa(net.layers(i).block, 'dagnn.BatchNorm')
            midx = net.getParamIndex(net.layers(i).params{1}); % multiplier
            bidx = net.getParamIndex(net.layers(i).params{2}); % bias
            % vectorize
            net.params(midx).value = net.params(midx).value(:);
            net.params(bidx).value = net.params(bidx).value(:);
        end
    end