-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #39 from graphcore-research/refactor-encode-round
Refactor encode/round
- Loading branch information
Showing
6 changed files
with
188 additions
and
174 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright (c) 2024 Graphcore Ltd. All rights reserved. | ||
|
||
import math | ||
|
||
import numpy as np | ||
|
||
from .types import FormatInfo | ||
|
||
|
||
def encode_float(fi: FormatInfo, v: float) -> int: | ||
""" | ||
Encode input to the given :py:class:`FormatInfo`. | ||
Will round toward zero if :paramref:`v` is not in the value set. | ||
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence. | ||
Encode -0 to 0 if not `fi.has_nz` | ||
For other roundings and saturations, call :func:`round_float` first. | ||
Args: | ||
fi (FormatInfo): Describes the target format | ||
v (float): The value to be encoded. | ||
Returns: | ||
The integer code point | ||
""" | ||
|
||
# Format Constants | ||
k = fi.bits | ||
p = fi.precision | ||
t = p - 1 | ||
|
||
# Encode | ||
if np.isnan(v): | ||
return fi.code_of_nan | ||
|
||
# Overflow/underflow | ||
if v > fi.max: | ||
if fi.has_infs: | ||
return fi.code_of_posinf | ||
if fi.num_nans > 0: | ||
return fi.code_of_nan | ||
return fi.code_of_max | ||
|
||
if v < fi.min: | ||
if fi.has_infs: | ||
return fi.code_of_neginf | ||
if fi.num_nans > 0: | ||
return fi.code_of_nan | ||
return fi.code_of_min | ||
|
||
# Finite values | ||
sign = fi.is_signed and np.signbit(v) | ||
vpos = -v if sign else v | ||
|
||
if fi.has_subnormals and vpos <= fi.smallest_subnormal / 2: | ||
isig = 0 | ||
biased_exp = 0 | ||
else: | ||
sig, exp = np.frexp(vpos) | ||
exp = int(exp) # All calculations in Python ints | ||
|
||
# sig in range [0.5, 1) | ||
sig *= 2 | ||
exp -= 1 | ||
# now sig in range [1, 2) | ||
|
||
biased_exp = exp + fi.expBias | ||
if biased_exp < 1 and fi.has_subnormals: | ||
# subnormal | ||
sig *= 2.0 ** (biased_exp - 1) | ||
biased_exp = 0 | ||
assert vpos == sig * 2 ** (1 - fi.expBias) | ||
else: | ||
if sig > 0: | ||
sig -= 1.0 | ||
|
||
isig = math.floor(sig * 2**t) | ||
|
||
# Zero | ||
if isig == 0 and biased_exp == 0 and fi.has_zero: | ||
if sign and fi.has_nz: | ||
return fi.code_of_negzero | ||
else: | ||
return fi.code_of_zero | ||
|
||
# Nonzero | ||
assert isig < 2**t | ||
assert biased_exp < 2**fi.expBits or fi.is_twos_complement | ||
|
||
# Handle two's complement encoding | ||
if fi.is_twos_complement and sign: | ||
isig = (1 << t) - isig | ||
|
||
# Pack values into a single integer | ||
code = (int(sign) << (k - 1)) | (biased_exp << t) | (isig << 0) | ||
|
||
return code |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright (c) 2024 Graphcore Ltd. All rights reserved. | ||
|
||
from .types import FormatInfo | ||
import numpy as np | ||
|
||
|
||
def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray: | ||
""" | ||
Vectorized version of :meth:`encode_float`. | ||
Encode inputs to the given :py:class:`FormatInfo`. | ||
Will round toward zero if :paramref:`v` is not in the value set. | ||
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence. | ||
Encode -0 to 0 if not `fi.has_nz` | ||
For other roundings and saturations, call :func:`round_ndarray` first. | ||
Args: | ||
fi (FormatInfo): Describes the target format | ||
v (float array): The value to be encoded. | ||
Returns: | ||
The integer code point | ||
""" | ||
k = fi.bits | ||
p = fi.precision | ||
t = p - 1 | ||
|
||
sign = np.signbit(v) & fi.is_signed | ||
vpos = np.where(sign, -v, v) | ||
|
||
nan_mask = np.isnan(v) | ||
|
||
code = np.zeros_like(v, dtype=np.uint64) | ||
|
||
if fi.num_nans > 0: | ||
code[nan_mask] = fi.code_of_nan | ||
else: | ||
assert not np.any(nan_mask) | ||
|
||
if fi.has_infs: | ||
code[v > fi.max] = fi.code_of_posinf | ||
code[v < fi.min] = fi.code_of_neginf | ||
else: | ||
code[v > fi.max] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_max | ||
code[v < fi.min] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min | ||
|
||
if fi.has_zero: | ||
if fi.has_nz: | ||
code[v == 0] = np.where(sign[v == 0], fi.code_of_negzero, fi.code_of_zero) | ||
else: | ||
code[v == 0] = fi.code_of_zero | ||
|
||
finite_mask = (code == 0) & (v != 0) | ||
assert not np.any(np.isnan(vpos[finite_mask])) | ||
if np.any(finite_mask): | ||
finite_vpos = vpos[finite_mask] | ||
finite_sign = sign[finite_mask] | ||
|
||
sig, exp = np.frexp(finite_vpos) | ||
|
||
biased_exp = exp.astype(np.int64) + (fi.expBias - 1) | ||
subnormal_mask = (biased_exp < 1) & fi.has_subnormals | ||
|
||
biased_exp_safe = np.where(subnormal_mask, biased_exp, 0) | ||
tsig = np.where(subnormal_mask, np.ldexp(sig, biased_exp_safe), sig * 2 - 1.0) | ||
biased_exp[subnormal_mask] = 0 | ||
|
||
isig = np.floor(np.ldexp(tsig, t)).astype(np.int64) | ||
|
||
zero_mask = fi.has_zero & (isig == 0) & (biased_exp == 0) | ||
if not fi.has_nz: | ||
finite_sign[zero_mask] = False | ||
|
||
# Handle two's complement encoding | ||
if fi.is_twos_complement: | ||
isig[finite_sign] = (1 << t) - isig[finite_sign] | ||
|
||
code[finite_mask] = ( | ||
(finite_sign.astype(int) << (k - 1)) | (biased_exp << t) | (isig << 0) | ||
) | ||
|
||
return code |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters