forked from VainF/Torch-Pruning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmetapruner.py
378 lines (322 loc) · 18.2 KB
/
metapruner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
import torch
import torch.nn as nn
import typing, warnings
from .scheduler import linear_scheduler
from ..import function
from ... import ops, dependency
class MetaPruner:
"""
Meta pruner for structural pruning.
Args:
# Basic
* model (nn.Module): A to-be-pruned model
* example_inputs (torch.Tensor or List): dummy inputs for graph tracing.
* importance (Callable): importance estimator.
* global_pruning (bool): enable global pruning. Default: False.
* ch_sparsity (float): global channel sparisty. Also known as pruning ratio. Default: 0.5.
* ch_sparsity_dict (Dict[nn.Module, float]): layer-specific sparsity. Will cover ch_sparsity if specified. Default: None.
* max_ch_sparsity (float): maximum channel sparsity. Default: 1.0.
* iterative_steps (int): number of steps for iterative pruning. Default: 1.
* iterative_sparsity_scheduler (Callable): scheduler for iterative pruning. Default: linear_scheduler.
* ignored_layers (List[nn.Module | typing.Type]): ignored modules. Default: None.
* round_to (int): round channels to the nearest multiple of round_to. E.g., round_to=8 means channels will be rounded to 8x. Default: None.
# Adavanced
* customized_pruners (dict): a dict containing module-pruner pairs. Default: None.
* unwrapped_parameters (dict): a dict containing unwrapped parameters & pruning dims. Default: None.
* root_module_types (list): types of prunable modules. Default: [nn.Conv2d, nn.Linear, nn.LSTM].
* forward_fn (Callable): A function to execute model.forward. Default: None.
* output_transform (Callable): A function to transform network outputs. Default: None.
"""
def __init__(
self,
# Basic
model: nn.Module, # a simple pytorch model
example_inputs: torch.Tensor, # a dummy input for graph tracing. Should be on the same
importance: typing.Callable, # tp.importance.Importance for group importance estimation
global_pruning: bool = False, # https://pytorch.org/tutorials/intermediate/pruning_tutorial.html#global-pruning.
ch_sparsity: float = 0.5, # channel/dim sparsity, also known as pruning ratio
ch_sparsity_dict: typing.Dict[nn.Module, float] = None, # layer-specific sparsity, will cover ch_sparsity if specified
max_ch_sparsity: float = 1.0, # maximum sparsity. useful if over-pruning happens.
iterative_steps: int = 1, # for iterative pruning
iterative_sparsity_scheduler: typing.Callable = linear_scheduler, # scheduler for iterative pruning.
ignored_layers: typing.List[nn.Module] = None, # ignored layers
round_to: int = None, # round channels to the nearest multiple of round_to
# Advanced
in_channel_groups: typing.Dict[nn.Module, int] = dict(), # The number of channel groups for layer input
out_channel_groups: typing.Dict[nn.Module, int] = dict(), # The number of channel groups for layer output
customized_pruners: typing.Dict[typing.Any, function.BasePruningFunc] = None, # pruners for customized layers. E.g., {nn.Linear: my_linear_pruner}
unwrapped_parameters: typing.Dict[nn.Parameter, int] = None, # unwrapped nn.Parameters & pruning_dims. For example, {ViT.pos_emb: 0}
root_module_types: typing.List = [ops.TORCH_CONV, ops.TORCH_LINEAR, ops.TORCH_LSTM], # root module for each group
forward_fn: typing.Callable = None, # a function to execute model.forward
output_transform: typing.Callable = None, # a function to transform network outputs
# deprecated
channel_groups: typing.Dict[nn.Module, int] = dict(), # channel groups for layers
):
self.model = model
self.importance = importance
self.ch_sparsity = ch_sparsity
self.ch_sparsity_dict = ch_sparsity_dict if ch_sparsity_dict is not None else {}
self.max_ch_sparsity = max_ch_sparsity
self.global_pruning = global_pruning
if len(channel_groups) > 0:
warnings.warn("channel_groups is deprecated. Please use in_channel_groups and out_channel_groups instead.")
out_channel_groups.update(channel_groups)
self.in_channel_groups = in_channel_groups
self.out_channel_groups = out_channel_groups
self.root_module_types = root_module_types
self.round_to = round_to
###############################################
# Build dependency graph
self.DG = dependency.DependencyGraph().build_dependency(
model,
example_inputs=example_inputs,
forward_fn=forward_fn,
output_transform=output_transform,
unwrapped_parameters=unwrapped_parameters,
customized_pruners=customized_pruners,
)
###############################################
# Ignored layers and submodules
self.ignored_layers = []
if ignored_layers is not None:
for layer in ignored_layers:
self.ignored_layers.extend(list(layer.modules()))
###############################################
# Iterative pruning
# The pruner will prune the model iteratively for several steps to achieve the target sparsity
# E.g., if iterative_steps=5, ch_sparsity=0.5, the sparsity of each step will be [0.1, 0.2, 0.3, 0.4, 0.5]
self.iterative_steps = iterative_steps
self.iterative_sparsity_scheduler = iterative_sparsity_scheduler
self.current_step = 0
# channel sparsity for each iterative step
self.per_step_ch_sparsity = self.iterative_sparsity_scheduler(
self.ch_sparsity, self.iterative_steps
)
###############################################
# Layer-specific sparsity. Will cover the global sparsity if specified
self.ch_sparsity_dict = {}
if ch_sparsity_dict is not None:
for module in ch_sparsity_dict:
sparsity = ch_sparsity_dict[module]
for submodule in module.modules():
prunable_types = tuple([ops.type2class(
prunable_type) for prunable_type in self.DG.REGISTERED_PRUNERS.keys()])
if isinstance(submodule, prunable_types):
self.ch_sparsity_dict[submodule] = self.iterative_sparsity_scheduler(
sparsity, self.iterative_steps
)
###############################################
# Detect group convs & group norms
for m in self.model.modules():
layer_pruner = self.DG.get_pruner_of_module(m)
in_ch_group = layer_pruner.get_in_channel_groups(m)
out_ch_group = layer_pruner.get_out_channel_groups(m)
if isinstance(m, ops.TORCH_CONV) and m.groups == m.out_channels:
continue
if in_ch_group > 1:
self.in_channel_groups[m] = in_ch_group
if out_ch_group > 1:
self.out_channel_groups[m] = out_ch_group
###############################################
# Initial channels/dims of each layer
self.layer_init_out_ch = {}
self.layer_init_in_ch = {}
for m in self.DG.module2node.keys():
if ops.module2type(m) in self.DG.REGISTERED_PRUNERS:
self.layer_init_out_ch[m] = self.DG.get_out_channels(m)
self.layer_init_in_ch[m] = self.DG.get_in_channels(m)
###############################################
# Count the number of total channels at initialization
if self.global_pruning:
initial_total_channels = 0
for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
group = self._downstream_node_as_root_if_unbind(group)
initial_total_channels += ( (self.DG.get_out_channels(group[0][0].target.module) ) // self._get_channel_groups(group) )
self.initial_total_channels = initial_total_channels
def step(self, interactive=False)-> typing.Union[typing.Generator, None]:
self.current_step += 1
pruning_method = self.prune_global if self.global_pruning else self.prune_local
if interactive: # yield groups for interactive pruning
return pruning_method()
else:
for group in pruning_method():
group.prune()
def estimate_importance(self, group, ch_groups=1) -> torch.Tensor:
return self.importance(group, ch_groups=ch_groups)
def pruning_history(self) -> typing.List[typing.Tuple[str, bool, typing.Union[list, tuple]]]:
return self.DG.pruning_history()
def load_pruning_history(self, pruning_history) -> None:
self.DG.load_pruning_history(pruning_history)
def get_target_sparsity(self, module) -> float:
s = self.ch_sparsity_dict.get(module, self.per_step_ch_sparsity)[self.current_step]
return min(s, self.max_ch_sparsity)
def reset(self) -> None:
self.current_step = 0
def regularize(self, model, loss) -> typing.Any:
""" Model regularizor for sparse training
"""
pass
def _check_sparsity(self, group) -> bool:
for dep, _ in group:
module = dep.target.module
pruning_fn = dep.handler
if dep.target.type == ops.OPTYPE.PARAMETER:
continue
if self.DG.is_out_channel_pruning_fn(pruning_fn):
layer_out_ch = self.DG.get_out_channels(module)
if layer_out_ch is None: continue
if layer_out_ch < self.layer_init_out_ch[module] * (
1 - self.max_ch_sparsity
) or layer_out_ch == 1:
return False
elif self.DG.is_in_channel_pruning_fn(pruning_fn):
layer_in_ch = self.DG.get_in_channels(module)
if layer_in_ch is None: continue
if layer_in_ch < self.layer_init_in_ch[module] * (
1 - self.max_ch_sparsity
) or layer_in_ch == 1:
return False
return True
def _get_channel_groups(self, group) -> int:
ch_groups = 1
has_unbind = False
unbind_node = None
for dep, _ in group:
module = dep.target.module
pruning_fn = dep.handler
channel_groups = self.out_channel_groups if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.in_channel_groups
if module in channel_groups:
ch_groups = channel_groups[module]
if dep.source.type==ops.OPTYPE.UNBIND:
has_unbind = True
unbind_node = dep.source
if has_unbind and ch_groups>1:
ch_groups = ch_groups // len(unbind_node.outputs)
return ch_groups # no channel grouping
def _downstream_node_as_root_if_unbind(self, group):
# Use a downstream node as the root if torch.unbind exists. TODO: find a general way to handle torch.unbind in timm
qkv_unbind = False
downstream_dep = None
for _dep, _idxs in group:
if _dep.source.type == ops.OPTYPE.UNBIND:
qkv_unbind = True
if isinstance(_dep.target.module, tuple(self.root_module_types)):
downstream_dep = _dep
if qkv_unbind and downstream_dep is not None: # use a downstream node as the root node
group = self.DG.get_pruning_group(downstream_dep.target.module, downstream_dep.handler, _idxs)
return group
def _round_to(self, n_pruned, current_channels, round_to):
rounded_channels = current_channels - n_pruned
rounded_channels = rounded_channels + (round_to - rounded_channels % round_to)
n_pruned = current_channels - rounded_channels
return n_pruned
def prune_local(self) -> typing.Generator:
if self.current_step > self.iterative_steps:
warnings.warn("Pruning exceed the maximum iterative steps, no pruning will be performed.")
return
for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
if self._check_sparsity(group): # check pruning ratio
group = self._downstream_node_as_root_if_unbind(group)
module = group[0][0].target.module
pruning_fn = group[0][0].handler
ch_groups = self._get_channel_groups(group)
imp = self.estimate_importance(group, ch_groups=ch_groups)
if imp is None: continue
if self.DG.is_out_channel_pruning_fn(pruning_fn):
current_channels = self.DG.get_out_channels(module)
target_sparsity = self.get_target_sparsity(module)
n_pruned = current_channels - int(
self.layer_init_out_ch[module] *
(1 - target_sparsity)
)
else:
current_channels = self.DG.get_in_channels(module)
target_sparsity = self.get_target_sparsity(module)
n_pruned = current_channels - int(
self.layer_init_in_ch[module] *
(1 - target_sparsity)
)
# round to the nearest multiple of round_to
if self.round_to:
n_pruned = self._round_to(n_pruned, current_channels, self.round_to)
if n_pruned <= 0:
continue
if ch_groups > 1: # independent pruning for each group
group_size = current_channels // ch_groups
pruning_idxs = []
n_pruned_per_group = n_pruned // ch_groups # max(1, n_pruned // ch_groups)
if n_pruned_per_group == 0: continue # skip
for chg in range(ch_groups):
sub_group_imp = imp[chg*group_size: (chg+1)*group_size]
sub_imp_argsort = torch.argsort(sub_group_imp)
sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group] + chg*group_size # offset
pruning_idxs.append(sub_pruning_idxs)
pruning_idxs = torch.cat(pruning_idxs, 0)
else: # no channel grouping
imp_argsort = torch.argsort(imp)
pruning_idxs = imp_argsort[:n_pruned]
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_idxs.tolist())
if self.DG.check_pruning_group(group):
yield group
def prune_global(self) -> typing.Generator:
if self.current_step > self.iterative_steps:
warnings.warn("Pruning exceed the maximum iterative steps, no pruning will be performed.")
return
# Pre-compute importance for each group
global_importance = []
for group in self.DG.get_all_groups(ignored_layers=self.ignored_layers, root_module_types=self.root_module_types):
if self._check_sparsity(group):
group = self._downstream_node_as_root_if_unbind(group)
ch_groups = self._get_channel_groups(group)
imp = self.estimate_importance(group, ch_groups=ch_groups)
if imp is None: continue
if ch_groups > 1:
imp = imp.view(ch_groups, -1).mean(dim=0) # average importance across groups. TODO: find a better way to handle grouped channels
global_importance.append((group, ch_groups, imp))
if len(global_importance) == 0:
return
# Find the threshold for global pruning
imp = torch.cat([local_imp[-1] for local_imp in global_importance], dim=0)
target_sparsity = self.per_step_ch_sparsity[self.current_step]
n_pruned = len(imp) - int(
self.initial_total_channels *
(1 - target_sparsity)
)
if n_pruned <= 0:
return
topk_imp, _ = torch.topk(imp, k=n_pruned, largest=False)
thres = topk_imp[-1]
# Group-by-group pruning
for group, ch_groups, imp in global_importance:
module = group[0][0].target.module
pruning_fn = group[0][0].handler
get_channel_fn = self.DG.get_out_channels if self.DG.is_out_channel_pruning_fn(pruning_fn) else self.DG.get_in_channels
pruning_indices = (imp <= thres).nonzero().view(-1)
if ch_groups > 1: # re-compute importance for each channel group if channel grouping is enabled
n_pruned_per_group = len(pruning_indices)
if n_pruned_per_group == 0: continue # skip
if self.round_to:
n_pruned = n_pruned_per_group * ch_groups
current_channels = get_channel_fn(module)
n_pruned = self._round_to(n_pruned, current_channels, self.round_to)
n_pruned_per_group = n_pruned // ch_groups
imp = self.estimate_importance(group, ch_groups=ch_groups) # re-compute importance
group_size = len(imp) // ch_groups
pruning_indices = []
for chg in range(ch_groups): # determine pruning indices for each channel group independently
sub_group_imp = imp[chg*group_size: (chg+1)*group_size]
sub_imp_argsort = torch.argsort(sub_group_imp)
sub_pruning_idxs = sub_imp_argsort[:n_pruned_per_group]+chg*group_size
pruning_indices.append(sub_pruning_idxs)
pruning_indices = torch.cat(pruning_indices, 0)
else:
if self.round_to:
n_pruned = len(pruning_indices)
current_channels = get_channel_fn(module)
n_pruned = self._round_to(n_pruned, current_channels, self.round_to)
pruning_indices = pruning_indices[:n_pruned]
group = self.DG.get_pruning_group(
module, pruning_fn, pruning_indices.tolist())
if self.DG.check_pruning_group(group):
yield group