-
Notifications
You must be signed in to change notification settings - Fork 103
/
Copy pathexample_ternary.py
112 lines (93 loc) · 3.39 KB
/
example_ternary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright 2020 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import # Not necessary in a Python 3-only module
from __future__ import division # Not necessary in a Python 3-only module
from __future__ import print_function # Not necessary in a Python 3-only module
from absl import app
from absl import flags
import matplotlib
import numpy as np
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
FLAGS = flags.FLAGS
def _stochastic_rounding(x, precision, resolution, delta):
"""Stochastic_rounding for ternary.
Args:
x:
precision: A float. The area we want to make this stochastic rounding.
[delta-precision, delta] [delta, delta+precision]
resolution: control the quantization resolution.
delta: the undiscountinued point (positive number)
Return:
A tensor with stochastic rounding numbers.
"""
delta_left = delta - precision
delta_right = delta + precision
scale = 1 / resolution
scale_delta_left = delta_left * scale
scale_delta_right = delta_right * scale
scale_2_delta = scale_delta_right - scale_delta_left
scale_x = x * scale
fraction = scale_x - scale_delta_left
# print(precision, scale, x[0], np.floor(scale_x[0]), scale_x[0], fraction[0])
# we use uniform distribution
random_selector = np.random.uniform(0, 1, size=x.shape) * scale_2_delta
# print(precision, scale, x[0], delta_left[0], delta_right[0])
# print('x', scale_x[0], fraction[0], random_selector[0], scale_2_delta[0])
# rounddown = fraction < random_selector
result = np.where(fraction < random_selector,
scale_delta_left / scale,
scale_delta_right / scale)
return result
def _ternary(x, sto=False):
m = np.amax(np.abs(x), keepdims=True)
scale = 2 * m / 3.0
thres = scale / 2.0
ratio = 0.1
if sto:
sign_bit = np.sign(x)
x = np.abs(x)
prec = x / scale
x = (
sign_bit * scale * _stochastic_rounding(
x / scale,
precision=0.3, resolution=0.01, # those two are all normalized.
delta=thres / scale))
# prec + prec *ratio)
# mm = np.amax(np.abs(x), keepdims=True)
return np.where(np.abs(x) < thres, np.zeros_like(x), np.sign(x))
def main(argv):
if len(argv) > 1:
raise app.UsageError('Too many command-line arguments.')
# x = np.arange(-3.0, 3.0, 0.01)
# x = np.random.uniform(-0.01, 0.01, size=1000)
x = np.random.uniform(-10.0, 10.0, size=1000)
# x = np.random.uniform(-1, 1, size=1000)
x = np.sort(x)
tr = np.zeros_like(x)
t = np.zeros_like(x)
iter_count = 500
for _ in range(iter_count):
y = _ternary(x)
yr = _ternary(x, sto=True)
t = t + y
tr = tr + yr
plt.plot(x, t/iter_count)
plt.plot(x, tr/iter_count)
plt.ylabel('mean (%s samples)' % iter_count)
plt.show()
if __name__ == '__main__':
app.run(main)