utils.py
'''Some helper functions for PyTorch, including:
    - get_mean_and_std: calculate the mean and std value of a dataset.
    - init_params: net parameter initialization (MSR / Kaiming init).
    - progress_bar: progress bar that mimics xlua.progress.
'''
import os
import sys
import time
import math
import ast

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init


def read_unknowns(unknown_list):
    """Parse leftover argparse tokens of the form '--KEY=value' into a flat dict.

    Input is of the form ['--METHOD.MODEL.LR=0.001758722642964502',
    '--METHOD.MODEL.NUM_LAYERS=1'].
    """
    ret = {}
    for item in unknown_list:
        key, value = item.split('=', 1)  # split on the first '=' only
        try:
            value = ast.literal_eval(value)  # e.g. '0.001' -> 0.001, '1' -> 1
        except (ValueError, SyntaxError):
            print("MALFORMED ", value)  # keep the raw string if it cannot be evaluated
        k = key[2:]  # strip the leading '--'
        ret[k] = value
    return ret
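
# Illustrative example (not part of the original file): leftover argparse
# tokens of the form '--A.B=value' become a flat dict with literal-evaluated
# values.
#   >>> read_unknowns(['--METHOD.MODEL.LR=0.001', '--METHOD.MODEL.NUM_LAYERS=1'])
#   {'METHOD.MODEL.LR': 0.001, 'METHOD.MODEL.NUM_LAYERS': 1}
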
def nest_dict(flat_dict, sep='.'):
    """Return nested dict by splitting the keys on a delimiter.

    >>> from pprint import pprint
    >>> pprint(nest_dict({'title': 'foo', 'author_name': 'stretch',
    ...                   'author_zipcode': '06901'}, sep='_'))
    {'author': {'name': 'stretch', 'zipcode': '06901'}, 'title': 'foo'}
    """
    tree = {}
    for key, val in flat_dict.items():
        t = tree
        prev = None
        # Walk the key parts, creating intermediate dicts as we go; the
        # for/else clause runs once the loop finishes and attaches the value
        # under the last part.
        for part in key.split(sep):
            if prev is not None:
                t = t.setdefault(prev, {})
            prev = part
        else:
            t.setdefault(prev, val)
    return tree
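
# Illustrative usage (an assumption about intended use, e.g. feeding a
# yacs-style config): read_unknowns and nest_dict compose to turn extra
# command-line overrides into a nested config dict.
#   >>> nest_dict(read_unknowns(['--METHOD.MODEL.LR=0.001']))
#   {'METHOD': {'MODEL': {'LR': 0.001}}}
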
def get_counts(labels):
    """Return inverse-frequency class weights, normalized so the largest weight is 1."""
    values, counts = np.unique(labels, return_counts=True)
    sorted_tuples = zip(*sorted(zip(values, counts)))  # ensure counts follow the sorted order of the class labels
    values, counts = [list(t) for t in sorted_tuples]
    fracs = 1 / torch.Tensor(counts)
    return fracs / torch.max(fracs)
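
# Illustrative usage (assumed, not shown in this file): the normalized
# inverse-frequency weights returned by get_counts can be passed to a
# class-weighted loss; `train_labels` below is a hypothetical array of
# integer class labels.
#   >>> get_counts([0, 0, 0, 1])
#   tensor([0.3333, 1.0000])
#   criterion = nn.CrossEntropyLoss(weight=get_counts(train_labels))
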
def get_mean_and_std(dataset):
'''Compute the mean and std value of dataset.'''
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=2)
mean = torch.zeros(3)
std = torch.zeros(3)
print('==> Computing mean and std..')
for inputs, targets in dataloader:
for i in range(3):
mean[i] += inputs[:,i,:,:].mean()
std[i] += inputs[:,i,:,:].std()
mean.div_(len(dataset))
std.div_(len(dataset))
return mean, std
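
# Illustrative usage (assumes torchvision is available; not part of this file):
#   import torchvision
#   import torchvision.transforms as transforms
#   trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
#                                           transform=transforms.ToTensor())
#   mean, std = get_mean_and_std(trainset)  # per-channel values for transforms.Normalize
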
def init_params(net):
    '''Init layer parameters.'''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
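
# Illustrative usage (assumed): apply the Kaiming/MSR initialization to a
# freshly constructed model before training; `MyModel` is a placeholder for
# any nn.Module.
#   net = MyModel()
#   init_params(net)
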
def format_time(seconds):
    '''Format a duration in seconds, keeping at most the two most significant units.'''
days = int(seconds / 3600/24)
seconds = seconds - days*3600*24
hours = int(seconds / 3600)
seconds = seconds - hours*3600
minutes = int(seconds / 60)
seconds = seconds - minutes*60
secondsf = int(seconds)
seconds = seconds - secondsf
millis = int(seconds*1000)
f = ''
i = 1
if days > 0:
f += str(days) + 'D'
i += 1
if hours > 0 and i <= 2:
f += str(hours) + 'h'
i += 1
if minutes > 0 and i <= 2:
f += str(minutes) + 'm'
i += 1
if secondsf > 0 and i <= 2:
f += str(secondsf) + 's'
i += 1
if millis > 0 and i <= 2:
f += str(millis) + 'ms'
i += 1
if f == '':
f = '0ms'
return f
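
# Illustrative behaviour (derived from the code above): at most the two most
# significant units are kept.
#   >>> format_time(3725.5)
#   '1h2m'
#   >>> format_time(0.256)
#   '256ms'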