-
Notifications
You must be signed in to change notification settings - Fork 20
/
Copy pathplot_results.m
65 lines (60 loc) · 2.16 KB
/
plot_results.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
function plot_results(expDir, datasetName, measures, savePath)
%PLOT_RESULTS Plot train/val curves of finished experiments and save them as a PDF.
%   PLOT_RESULTS(EXPDIR, DATASETNAME, MEASURES, SAVEPATH) looks in EXPDIR for
%   experiment folders named '<datasetName>-<arch>-<n>' (arch is 'plain' or
%   'resnet', n an integer). For each one it loads the 'stats' variable from
%   its last 'net-epoch-<k>.mat' checkpoint and plots stats.train.(measure)
%   (dotted) and stats.val.(measure) (solid). The figure has one subplot row
%   per measure and one column per architecture. It is printed to SAVEPATH.
%
%   Inputs:
%     expDir      - directory containing the experiment folders.
%     datasetName - folder-name prefix (default 'cifar').
%     measures    - char or cell array of field names of stats.train/stats.val.
%                   Default: {'error'} for 'cifar', {'error','error5'} for
%                   'imagenet', {'error'} for any other dataset.
%     savePath    - output file ending in '.pdf', or a directory in which
%                   '<datasetName>-summary.pdf' is written (default: expDir).
%
%   Usage example: plot_results('exp', 'cifar', {'error'}, 'exp/summary.pdf');
if ~exist('datasetName', 'var') || isempty(datasetName)
  datasetName = 'cifar';
end
if ~exist('measures', 'var') || isempty(measures)
  if strcmpi(datasetName, 'imagenet')
    measures = {'error', 'error5'};
  else
    % 'cifar' and any unrecognised dataset: previously an unknown dataset
    % left 'measures' undefined and crashed below.
    measures = {'error'};
  end
end
if ~exist('savePath', 'var') || isempty(savePath)
  savePath = expDir;
end
if ischar(measures), measures = {measures}; end
% Treat savePath as a file only when it ends in '.pdf'; otherwise it is a
% directory. (A strfind-based test errored when '.pdf' occurred twice.)
if numel(savePath) < 4 || ~strcmp(savePath(end-3:end), '.pdf')
  savePath = fullfile(savePath, [datasetName '-summary.pdf']);
end

plots = {'plain', 'resnet'};
nPlots = numel(plots);
nMeasures = numel(measures);
figure(1); clf;
cmap = lines;
for col = 1:nPlots
  p = plots{col};
  Ns = listExperimentSizes(expDir, datasetName, p);
  for k = 1:nMeasures
    % Fixed grid: one row per measure, one column per architecture. Growing
    % the row count with k (as before) made later subplots overlap and
    % delete the earlier axes whenever more than one measure was plotted.
    subplot(nMeasures, nPlots, (k-1)*nPlots + col);
    hold on;
    leg = {}; Hs = []; nEpochs = 0;
    for i = 1:numel(Ns)
      n = Ns(i);
      tmpDir = fullfile(expDir, sprintf('%s-%s-%d', datasetName, p, n));
      epoch = findLastCheckpoint(tmpDir);
      if epoch == 0, continue; end   % no checkpoint saved yet
      S = load(fullfile(tmpDir, sprintf('net-epoch-%d.mat', epoch)), 'stats');
      stats = S.stats;
      % Colour by position in the sorted size list; wrap if it is exhausted.
      color = cmap(mod(i-1, size(cmap, 1)) + 1, :);
      plot([stats.train.(measures{k})], ':', 'Color', color, 'LineWidth', 1.5);
      Hs(end+1) = plot([stats.val.(measures{k})], '-', 'Color', color, 'LineWidth', 1.5);
      % Label is '<arch>-<6n+2>'; 6n+2 is presumably the network depth
      % (CIFAR ResNet convention) -- TODO confirm for other datasets.
      leg{end+1} = sprintf('%s-%d', p, 6*n+2);
      nEpochs = max(nEpochs, numel(stats.train));
    end
    xlabel('epoch');
    ylabel(measures{k});
    title(p);
    if ~isempty(Hs)
      legend(Hs, leg{:}, 'Location', 'NorthEast');
    end
    ylim([0 .75]);   % fixed y-range shared by all subplots
    % xlim needs an increasing range; skip it when 0 or 1 epochs were found.
    if nEpochs > 1
      xlim([1 nEpochs]);
    end
    set(gca, 'YGrid', 'on');
  end
end
drawnow;
print(1, savePath, '-dpdf');
end

function Ns = listExperimentSizes(expDir, datasetName, p)
%LISTEXPERIMENTSIZES Sorted unique n of folders '<datasetName>-<p>-<n>' in expDir.
%   Entries matching the dir() wildcard but not the exact name pattern (e.g.
%   '<datasetName>-<p>-foo') are ignored instead of crashing the parse.
list = dir(fullfile(expDir, sprintf('%s-%s-*', datasetName, p)));
list = list([list.isdir]);
pattern = sprintf('^%s-%s-(\\d+)$', regexptranslate('escape', datasetName), p);
tokens = regexp({list.name}, pattern, 'tokens', 'once');
tokens = tokens(~cellfun(@isempty, tokens));
Ns = unique(cellfun(@(t) sscanf(t{1}, '%d'), tokens));
end
function epoch = findLastCheckpoint(modelDir)
%FINDLASTCHECKPOINT Highest epoch number among 'net-epoch-<k>.mat' files.
%   EPOCH = FINDLASTCHECKPOINT(MODELDIR) lists MODELDIR for files named
%   exactly 'net-epoch-<k>.mat' and returns the largest integer k, or 0
%   when no such file is found. Files that match the dir() wildcard but not
%   the exact name (e.g. 'net-epoch-3-old.mat') are ignored; previously
%   they produced empty tokens and crashed the cellfun below.
list = dir(fullfile(modelDir, 'net-epoch-*.mat'));
% Anchored, with the '.' escaped, so only exact checkpoint names match.
tokens = regexp({list.name}, '^net-epoch-(\d+)\.mat$', 'tokens', 'once');
tokens = tokens(~cellfun(@isempty, tokens));
epochs = cellfun(@(t) sscanf(t{1}, '%d'), tokens);
% Appending 0 makes the result 0 when no checkpoint exists.
epoch = max([epochs(:); 0]);
end