-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add FP ULP comparisons and inspector tool
- Loading branch information
Showing
4 changed files
with
235 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,5 @@ | ||
# SPDX-FileCopyrightText: 2023-present Eric T. Johnson | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
from yut23_utils.fp import FloatInspector, compare_ulp, ulp_diff |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
# SPDX-FileCopyrightText: 2023-present Eric T. Johnson | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
"""Tools for working with floating-point numbers.""" | ||
|
||
import math | ||
import struct | ||
from dataclasses import dataclass, field | ||
from functools import cached_property | ||
from typing import cast | ||
|
||
|
||
def float_to_int(x: float, /) -> int: | ||
return cast(int, struct.unpack("<Q", struct.pack("<d", x))[0]) | ||
|
||
|
||
def int_to_float(q: int, /) -> float: | ||
return cast(float, struct.unpack("<d", struct.pack("<Q", q))[0]) | ||
|
||
|
||
def ulp_diff(a: float, b: float) -> int: | ||
"""Return the (signed) number of representable FP64 values in the range [a, b).""" | ||
if not math.isfinite(a) or not math.isfinite(b): | ||
raise ValueError("only finite values can be compared") | ||
if a == b: | ||
return 0 | ||
if a > b: | ||
# pylint: disable-next=arguments-out-of-order | ||
return -ulp_diff(b, a) | ||
if math.copysign(1.0, a) != math.copysign(1.0, b): | ||
# different signs: split the interval at zero | ||
return ulp_diff(a, -0.0) + ulp_diff(0.0, b) | ||
return abs(float_to_int(a) - float_to_int(b)) | ||
|
||
|
||
def compare_ulp(a: float, b: float, /, ulps: int) -> bool: | ||
"""Check if two numbers match to within a specified number of FP64 ULPs.""" | ||
if ulps < 0: | ||
raise ValueError("ulps must be non-negative") | ||
return abs(ulp_diff(a, b)) <= ulps | ||
|
||
|
||
@dataclass(frozen=True) | ||
class FloatInspector: | ||
float_val: float | ||
int_val: int = field(init=False) | ||
|
||
SIGN_MASK = 0x8000000000000000 | ||
EXP_MASK = 0x7FF0000000000000 | ||
MANT_MASK = 0x000FFFFFFFFFFFFF | ||
|
||
def __post_init__(self): | ||
object.__setattr__(self, "int_val", float_to_int(self.float_val)) | ||
|
||
@cached_property | ||
def raw_sign(self) -> int: | ||
return (self.int_val & self.SIGN_MASK) >> 63 | ||
|
||
@cached_property | ||
def raw_exponent(self) -> int: | ||
return (self.int_val & self.EXP_MASK) >> 52 | ||
|
||
@cached_property | ||
def raw_mantissa(self) -> int: | ||
return self.int_val & self.MANT_MASK | ||
|
||
def __str__(self) -> str: | ||
return ( | ||
f"FloatData({float(self)} = " | ||
f"{self.raw_sign} * 2^{self.exponent} * {self.mantissa}; " | ||
f"s={self.raw_sign}, e={self.raw_exponent:011b}, m={self.raw_mantissa:052b})" | ||
) | ||
|
||
def __repr__(self) -> str: | ||
f = float(self) | ||
return f"FloatData({f}; {f.hex()})" | ||
|
||
def __float__(self) -> float: | ||
return self.float_val | ||
|
||
@cached_property | ||
def exponent(self) -> int: | ||
return self.raw_exponent - 1023 | ||
|
||
@cached_property | ||
def mantissa(self) -> float: | ||
frac = self.raw_mantissa / (1 << 52) | ||
if self.raw_exponent == 0: | ||
# zero and subnormals | ||
frac *= 2 | ||
else: | ||
frac += 1 | ||
if self.raw_sign == 1: | ||
frac *= -1 | ||
return frac | ||
|
||
def is_negative(self) -> bool: | ||
return bool(self.raw_sign) | ||
|
||
def is_inf(self) -> bool: | ||
return self.raw_exponent == 0x7FF and self.raw_mantissa == 0 | ||
|
||
def is_nan(self) -> bool: | ||
return self.raw_exponent == 0x7FF and self.raw_mantissa != 0 | ||
|
||
def is_subnormal(self) -> bool: | ||
return self.raw_exponent == 0 and self.raw_mantissa != 0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# SPDX-FileCopyrightText: 2023-present Eric T. Johnson | ||
# | ||
# SPDX-License-Identifier: BSD-3-Clause | ||
|
||
import math | ||
import sys | ||
|
||
import pytest | ||
from hypothesis import assume, example, given, note | ||
from hypothesis import strategies as st | ||
|
||
from yut23_utils.fp import ( | ||
FloatInspector, | ||
compare_ulp, | ||
float_to_int, | ||
int_to_float, | ||
ulp_diff, | ||
) | ||
|
||
|
||
@given(st.floats()) | ||
def test_float_to_int(x_flt: float) -> None: | ||
x_int = float_to_int(x_flt) | ||
assert isinstance(x_int, int) | ||
assert 0 <= x_int < 1 << 64 | ||
x_flt_2 = int_to_float(x_int) | ||
if math.isnan(x_flt): | ||
assert math.isnan(x_flt_2) | ||
else: | ||
assert x_flt_2 == x_flt | ||
|
||
|
||
@given(st.integers(min_value=0, max_value=(1 << 64) - 1)) | ||
def test_int_to_float(x_int: int) -> None: | ||
x_flt = int_to_float(x_int) | ||
assert float_to_int(x_flt) == x_int | ||
|
||
|
||
@st.composite | ||
def float_pairs(draw: st.DrawFn, max_ulps: int) -> tuple[float, float, int]: | ||
a = draw(st.floats(allow_nan=False, allow_infinity=False)) | ||
ulps = draw(st.integers(min_value=-max_ulps, max_value=max_ulps)) | ||
|
||
b = a | ||
direction = math.copysign(math.inf, ulps) | ||
for _ in range(abs(ulps)): | ||
b = math.nextafter(b, direction) | ||
assume(not math.isinf(b)) | ||
|
||
note(f"a = {FloatInspector(a)}\nb = {FloatInspector(b)}") | ||
return a, b, ulps | ||
|
||
|
||
@given(float_pairs(max_ulps=100)) | ||
@example((0.0, -5e-324, -1)) | ||
@example((5e-324, 0.0, -1)) | ||
@example((5e-324, -0.0, -1)) | ||
@example((-5e-324, 5e-324, 2)) | ||
def test_ulp_diff(args: tuple[float, float, int]) -> None: | ||
a, b, ulps = args | ||
assert ulp_diff(a, b) == ulps | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"a,b", | ||
[ | ||
(math.inf, 3.14), | ||
(math.inf, math.inf), | ||
(-math.inf, math.inf), | ||
(math.nan, -2.718), | ||
(math.nan, math.nan), | ||
], | ||
) | ||
def test_ulp_diff_errors(a: float, b: float) -> None: | ||
with pytest.raises(ValueError): | ||
ulp_diff(a, b) | ||
with pytest.raises(ValueError): | ||
ulp_diff(b, a) | ||
|
||
|
||
@given(float_pairs(max_ulps=100)) | ||
@example((-5e-324, 5e-324, 2)) | ||
def test_compare_ulp(args: tuple[float, float, int]) -> None: | ||
a, b, ulps = args | ||
assert compare_ulp(a, b, abs(ulps)) | ||
if abs(ulps) > 0: | ||
assert not compare_ulp(a, b, abs(ulps) - 1) | ||
assert compare_ulp(a, b, abs(ulps) + 1) | ||
|
||
|
||
def test_compare_ulp_errors() -> None: | ||
with pytest.raises(ValueError): | ||
compare_ulp(1.0, 1.0, -1) | ||
|
||
|
||
@given(st.floats()) | ||
def test_FloatInspector(x: float) -> None: | ||
fi = FloatInspector(x) | ||
note(f"fi={fi} (subnormal: {fi.is_subnormal()})") | ||
if not math.isnan(x): | ||
assert float(fi) == x | ||
# convert to an int to check raw NaN representation | ||
assert float_to_int(float(fi)) == float_to_int(x) | ||
|
||
# try manually reconstructing the value from the mantissa and exponent | ||
if math.isfinite(x): | ||
reconstructed = math.ldexp(fi.mantissa, fi.exponent) | ||
assert reconstructed == x | ||
|
||
# check interrogator methods | ||
# need to use copysign rather than `x < 0.0` to handle -0.0 properly | ||
is_negative = math.copysign(1.0, x) == -1.0 | ||
assert fi.is_negative() == is_negative | ||
assert fi.is_inf() == math.isinf(x) | ||
assert fi.is_nan() == math.isnan(x) | ||
is_subnormal = 0 < abs(x) < sys.float_info.min | ||
assert fi.is_subnormal() == is_subnormal |