Skip to content

Commit

Permalink
Merge pull request #39 from graphcore-research/refactor-encode-round
Browse files Browse the repository at this point in the history
Refactor encode/round
  • Loading branch information
awf authored Sep 19, 2024
2 parents 0257255 + 1d20838 commit 564c104
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 174 deletions.
6 changes: 4 additions & 2 deletions src/gfloat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
)
from .decode import decode_float
from .printing import float_pow2str, float_tilde_unless_roundtrip_str
from .round import encode_float, round_float
from .round_ndarray import encode_ndarray, round_ndarray
from .round import round_float
from .encode import encode_float
from .round_ndarray import round_ndarray
from .encode_ndarray import encode_ndarray
from .decode_ndarray import decode_ndarray
from .types import FloatClass, FloatValue, FormatInfo, RoundMode

Expand Down
3 changes: 2 additions & 1 deletion src/gfloat/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import numpy.typing as npt

from .decode import decode_float
from .round import RoundMode, encode_float, round_float
from .round import RoundMode, round_float
from .encode import encode_float
from .types import FormatInfo


Expand Down
98 changes: 98 additions & 0 deletions src/gfloat/encode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

import math

import numpy as np

from .types import FormatInfo


def encode_float(fi: FormatInfo, v: float) -> int:
"""
Encode input to the given :py:class:`FormatInfo`.
Will round toward zero if :paramref:`v` is not in the value set.
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence.
Encode -0 to 0 if not `fi.has_nz`
For other roundings and saturations, call :func:`round_float` first.
Args:
fi (FormatInfo): Describes the target format
v (float): The value to be encoded.
Returns:
The integer code point
"""

# Format Constants
k = fi.bits
p = fi.precision
t = p - 1

# Encode
if np.isnan(v):
return fi.code_of_nan

# Overflow/underflow
if v > fi.max:
if fi.has_infs:
return fi.code_of_posinf
if fi.num_nans > 0:
return fi.code_of_nan
return fi.code_of_max

if v < fi.min:
if fi.has_infs:
return fi.code_of_neginf
if fi.num_nans > 0:
return fi.code_of_nan
return fi.code_of_min

# Finite values
sign = fi.is_signed and np.signbit(v)
vpos = -v if sign else v

if fi.has_subnormals and vpos <= fi.smallest_subnormal / 2:
isig = 0
biased_exp = 0
else:
sig, exp = np.frexp(vpos)
exp = int(exp) # All calculations in Python ints

# sig in range [0.5, 1)
sig *= 2
exp -= 1
# now sig in range [1, 2)

biased_exp = exp + fi.expBias
if biased_exp < 1 and fi.has_subnormals:
# subnormal
sig *= 2.0 ** (biased_exp - 1)
biased_exp = 0
assert vpos == sig * 2 ** (1 - fi.expBias)
else:
if sig > 0:
sig -= 1.0

isig = math.floor(sig * 2**t)

# Zero
if isig == 0 and biased_exp == 0 and fi.has_zero:
if sign and fi.has_nz:
return fi.code_of_negzero
else:
return fi.code_of_zero

# Nonzero
assert isig < 2**t
assert biased_exp < 2**fi.expBits or fi.is_twos_complement

# Handle two's complement encoding
if fi.is_twos_complement and sign:
isig = (1 << t) - isig

# Pack values into a single integer
code = (int(sign) << (k - 1)) | (biased_exp << t) | (isig << 0)

return code
84 changes: 84 additions & 0 deletions src/gfloat/encode_ndarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) 2024 Graphcore Ltd. All rights reserved.

from .types import FormatInfo
import numpy as np


def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray:
"""
Vectorized version of :meth:`encode_float`.
Encode inputs to the given :py:class:`FormatInfo`.
Will round toward zero if :paramref:`v` is not in the value set.
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence.
Encode -0 to 0 if not `fi.has_nz`
For other roundings and saturations, call :func:`round_ndarray` first.
Args:
fi (FormatInfo): Describes the target format
v (float array): The value to be encoded.
Returns:
The integer code point
"""
k = fi.bits
p = fi.precision
t = p - 1

sign = np.signbit(v) & fi.is_signed
vpos = np.where(sign, -v, v)

nan_mask = np.isnan(v)

code = np.zeros_like(v, dtype=np.uint64)

if fi.num_nans > 0:
code[nan_mask] = fi.code_of_nan
else:
assert not np.any(nan_mask)

if fi.has_infs:
code[v > fi.max] = fi.code_of_posinf
code[v < fi.min] = fi.code_of_neginf
else:
code[v > fi.max] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_max
code[v < fi.min] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min

if fi.has_zero:
if fi.has_nz:
code[v == 0] = np.where(sign[v == 0], fi.code_of_negzero, fi.code_of_zero)
else:
code[v == 0] = fi.code_of_zero

finite_mask = (code == 0) & (v != 0)
assert not np.any(np.isnan(vpos[finite_mask]))
if np.any(finite_mask):
finite_vpos = vpos[finite_mask]
finite_sign = sign[finite_mask]

sig, exp = np.frexp(finite_vpos)

biased_exp = exp.astype(np.int64) + (fi.expBias - 1)
subnormal_mask = (biased_exp < 1) & fi.has_subnormals

biased_exp_safe = np.where(subnormal_mask, biased_exp, 0)
tsig = np.where(subnormal_mask, np.ldexp(sig, biased_exp_safe), sig * 2 - 1.0)
biased_exp[subnormal_mask] = 0

isig = np.floor(np.ldexp(tsig, t)).astype(np.int64)

zero_mask = fi.has_zero & (isig == 0) & (biased_exp == 0)
if not fi.has_nz:
finite_sign[zero_mask] = False

# Handle two's complement encoding
if fi.is_twos_complement:
isig[finite_sign] = (1 << t) - isig[finite_sign]

code[finite_mask] = (
(finite_sign.astype(int) << (k - 1)) | (biased_exp << t) | (isig << 0)
)

return code
91 changes: 0 additions & 91 deletions src/gfloat/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,94 +166,3 @@ def round_float(
result = -result

return result


def encode_float(fi: FormatInfo, v: float) -> int:
"""
Encode input to the given :py:class:`FormatInfo`.
Will round toward zero if :paramref:`v` is not in the value set.
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence.
Encode -0 to 0 if not `fi.has_nz`
For other roundings and saturations, call :func:`round_float` first.
Args:
fi (FormatInfo): Describes the target format
v (float): The value to be encoded.
Returns:
The integer code point
"""

# Format Constants
k = fi.bits
p = fi.precision
t = p - 1

# Encode
if np.isnan(v):
return fi.code_of_nan

# Overflow/underflow
if v > fi.max:
if fi.has_infs:
return fi.code_of_posinf
if fi.num_nans > 0:
return fi.code_of_nan
return fi.code_of_max

if v < fi.min:
if fi.has_infs:
return fi.code_of_neginf
if fi.num_nans > 0:
return fi.code_of_nan
return fi.code_of_min

# Finite values
sign = fi.is_signed and np.signbit(v)
vpos = -v if sign else v

if fi.has_subnormals and vpos <= fi.smallest_subnormal / 2:
isig = 0
biased_exp = 0
else:
sig, exp = np.frexp(vpos)
exp = int(exp) # All calculations in Python ints

# sig in range [0.5, 1)
sig *= 2
exp -= 1
# now sig in range [1, 2)

biased_exp = exp + fi.expBias
if biased_exp < 1 and fi.has_subnormals:
# subnormal
sig *= 2.0 ** (biased_exp - 1)
biased_exp = 0
assert vpos == sig * 2 ** (1 - fi.expBias)
else:
if sig > 0:
sig -= 1.0

isig = math.floor(sig * 2**t)

# Zero
if isig == 0 and biased_exp == 0 and fi.has_zero:
if sign and fi.has_nz:
return fi.code_of_negzero
else:
return fi.code_of_zero

# Nonzero
assert isig < 2**t
assert biased_exp < 2**fi.expBits or fi.is_twos_complement

# Handle two's complement encoding
if fi.is_twos_complement and sign:
isig = (1 << t) - isig

# Pack values into a single integer
code = (int(sign) << (k - 1)) | (biased_exp << t) | (isig << 0)

return code
80 changes: 0 additions & 80 deletions src/gfloat/round_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,83 +150,3 @@ def round_ndarray(
result = np.where(result == 0, 0.0, result)

return result


def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray:
"""
Vectorized version of :meth:`encode_float`.
Encode inputs to the given :py:class:`FormatInfo`.
Will round toward zero if :paramref:`v` is not in the value set.
Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence.
Encode -0 to 0 if not `fi.has_nz`
For other roundings and saturations, call :func:`round_ndarray` first.
Args:
fi (FormatInfo): Describes the target format
v (float array): The value to be encoded.
Returns:
The integer code point
"""
k = fi.bits
p = fi.precision
t = p - 1

sign = np.signbit(v) & fi.is_signed
vpos = np.where(sign, -v, v)

nan_mask = np.isnan(v)

code = np.zeros_like(v, dtype=np.uint64)

if fi.num_nans > 0:
code[nan_mask] = fi.code_of_nan
else:
assert not np.any(nan_mask)

if fi.has_infs:
code[v > fi.max] = fi.code_of_posinf
code[v < fi.min] = fi.code_of_neginf
else:
code[v > fi.max] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_max
code[v < fi.min] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min

if fi.has_zero:
if fi.has_nz:
code[v == 0] = np.where(sign[v == 0], fi.code_of_negzero, fi.code_of_zero)
else:
code[v == 0] = fi.code_of_zero

finite_mask = (code == 0) & (v != 0)
assert not np.any(np.isnan(vpos[finite_mask]))
if np.any(finite_mask):
finite_vpos = vpos[finite_mask]
finite_sign = sign[finite_mask]

sig, exp = np.frexp(finite_vpos)

biased_exp = exp.astype(np.int64) + (fi.expBias - 1)
subnormal_mask = (biased_exp < 1) & fi.has_subnormals

biased_exp_safe = np.where(subnormal_mask, biased_exp, 0)
tsig = np.where(subnormal_mask, np.ldexp(sig, biased_exp_safe), sig * 2 - 1.0)
biased_exp[subnormal_mask] = 0

isig = np.floor(np.ldexp(tsig, t)).astype(np.int64)

zero_mask = fi.has_zero & (isig == 0) & (biased_exp == 0)
if not fi.has_nz:
finite_sign[zero_mask] = False

# Handle two's complement encoding
if fi.is_twos_complement:
isig[finite_sign] = (1 << t) - isig[finite_sign]

code[finite_mask] = (
(finite_sign.astype(int) << (k - 1)) | (biased_exp << t) | (isig << 0)
)

return code

0 comments on commit 564c104

Please sign in to comment.