forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_model_dump.py
114 lines (92 loc) · 3.49 KB
/
test_model_dump.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
import sys
import io
import unittest
import torch
import torch.utils.model_dump
import torch.utils.mobile_optimizer
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_quantized import supported_qengines
class SimpleModel(torch.nn.Module):
    """Small feed-forward network: Linear(16, 64) -> ReLU -> Linear(64, 8) -> ReLU."""

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Submodule attribute names are referenced by string when modules are
        # fused in this file ("core.layer1", "core.relu1", ...), so they are
        # part of this class's contract.
        self.layer1 = nn.Linear(16, 64)
        self.relu1 = nn.ReLU()
        self.layer2 = nn.Linear(64, 8)
        self.relu2 = nn.ReLU()

    def forward(self, features):
        """Apply both Linear/ReLU stages to ``features`` and return the result."""
        hidden = self.relu1(self.layer1(features))
        return self.relu2(self.layer2(hidden))
class QuantModel(torch.nn.Module):
    """Wrap ``SimpleModel`` between a ``QuantStub`` and a ``DeQuantStub``."""

    def __init__(self):
        super().__init__()
        quantization = torch.quantization
        self.quant = quantization.QuantStub()
        self.dequant = quantization.DeQuantStub()
        # The attribute name "core" is used as a path prefix when modules are
        # fused in this file, so it must not be renamed.
        self.core = SimpleModel()

    def forward(self, x):
        """Run ``x`` through quant stub, core model, then dequant stub."""
        return self.dequant(self.core(self.quant(x)))
class ModelWithLists(torch.nn.Module):
    """Module whose attributes are plain Python lists of tensors.

    ``rt`` holds only tensors; ``ot`` mixes a tensor with ``None``.
    """

    def __init__(self):
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        """Return ``arg + rt[0]``, plus ``ot[0]`` when that entry is not None."""
        total = arg + self.rt[0]
        # Bind the possibly-None entry to a local before testing it.
        maybe_extra = self.ot[0]
        if maybe_extra is not None:
            total = total + maybe_extra
        return total
class TestModelDump(TestCase):
    """Smoke tests for ``torch.utils.model_dump``.

    Most tests build a TorchScript module, save it to an in-memory buffer and
    check that ``get_model_info`` returns something for it.
    """

    @unittest.skipIf(sys.version_info < (3, 7), "importlib.resources was new in 3.7")
    def test_inline_skeleton(self):
        """The inline skeleton must not contain the checked external references."""
        skel = torch.utils.model_dump.get_inline_skeleton()
        # NOTE(review): this only guards against the literal "unpkg.org";
        # confirm that this is the host the skeleton sources actually use.
        self.assertNotIn("unpkg.org", skel)
        self.assertNotIn("src=", skel)

    def do_dump_model(self, model, extra_files=None):
        """Save ``model`` to a buffer and check that model info can be extracted.

        Args:
            model: Object accepted by ``torch.jit.save``.
            extra_files: Optional mapping forwarded as ``_extra_files`` to
                ``torch.jit.save``.
        """
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        # Use the TestCase assertion rather than a bare ``assert`` so the
        # check is not stripped when Python runs with -O.
        self.assertIsNotNone(info)

    def test_scripted_model(self):
        """Dump a scripted ``SimpleModel``."""
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        """Dump a ``SimpleModel`` traced with a (2, 16) zero input."""
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def get_quant_model(self):
        """Build and return a converted (quantized) ``QuantModel``.

        The model is put in eval mode, each Linear/ReLU pair of the core is
        fused, the default "qnnpack" qconfig is attached, and the model is
        prepared, run once on random input, and converted.
        """
        fmodel = QuantModel().eval()
        fmodel = torch.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.quantization.get_default_qconfig("qnnpack")
        prepped = torch.quantization.prepare(fmodel)
        # One forward pass on random data before conversion (presumably to
        # give the inserted observers something to calibrate on).
        prepped(torch.randn(2, 16))
        qmodel = torch.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        """Dump a scripted quantized model."""
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        """Dump a traced quantized model after ``optimize_for_mobile``."""
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        """Dump a scripted module that stores list attributes."""
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
        """Dumping must still succeed when an extra ``.json`` file is malformed."""
        model = torch.jit.script(SimpleModel())
        # "{" is deliberately not valid JSON.
        self.do_dump_model(model, extra_files={"foo.json": "{"})
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()