Skip to content

Commit

Permalink
Add FP ULP comparisons and inspector tool
Browse files Browse the repository at this point in the history
  • Loading branch information
yut23 committed Jan 9, 2024
1 parent 33d942c commit 48d29d9
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 2 deletions.
11 changes: 9 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ path = "src/yut23_utils/__about__.py"
dependencies = [
"coverage[toml]>=6.5",
"pytest",
"hypothesis",
]
[tool.hatch.envs.default.scripts]
test = "pytest {args:tests}"
Expand All @@ -60,6 +61,7 @@ dependencies = [
"black>=23.1.0",
"mypy>=1.0.0",
"ruff>=0.0.243",
"hypothesis",
]
[tool.hatch.envs.lint.scripts]
typing = "mypy --install-types --non-interactive {args:src/yut23_utils tests}"
Expand All @@ -79,12 +81,12 @@ all = [

[tool.black]
target-version = ["py38"]
line-length = 120
line-length = 88
skip-string-normalization = false

[tool.ruff]
target-version = "py38"
line-length = 120
line-length = 88
select = [
"A",
"ARG",
Expand Down Expand Up @@ -127,6 +129,9 @@ unfixable = [
"F401",
]

[tool.isort]
known_first_party = ["yut23_utils"]

[tool.ruff.isort]
known-first-party = ["yut23_utils"]

Expand All @@ -136,6 +141,8 @@ ban-relative-imports = "all"
[tool.ruff.per-file-ignores]
# Tests can use magic values, assertions, and relative imports
"tests/**/*" = ["PLR2004", "S101", "TID252"]
# __init__.py can import things without using them
"__init__.py" = ["F401"]

[tool.coverage.run]
source_pkgs = ["yut23_utils", "tests"]
Expand Down
2 changes: 2 additions & 0 deletions src/yut23_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# SPDX-FileCopyrightText: 2023-present Eric T. Johnson
#
# SPDX-License-Identifier: BSD-3-Clause

from yut23_utils.fp import FloatInspector, compare_ulp, ulp_diff
107 changes: 107 additions & 0 deletions src/yut23_utils/fp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# SPDX-FileCopyrightText: 2023-present Eric T. Johnson
#
# SPDX-License-Identifier: BSD-3-Clause
"""Tools for working with floating-point numbers."""

import math
import struct
from dataclasses import dataclass, field
from functools import cached_property
from typing import cast


def float_to_int(x: float, /) -> int:
return cast(int, struct.unpack("<Q", struct.pack("<d", x))[0])


def int_to_float(q: int, /) -> float:
return cast(float, struct.unpack("<d", struct.pack("<Q", q))[0])


def ulp_diff(a: float, b: float) -> int:
"""Return the (signed) number of representable FP64 values in the range [a, b)."""
if not math.isfinite(a) or not math.isfinite(b):
raise ValueError("only finite values can be compared")
if a == b:
return 0
if a > b:
# pylint: disable-next=arguments-out-of-order
return -ulp_diff(b, a)
if math.copysign(1.0, a) != math.copysign(1.0, b):
# different signs: split the interval at zero
return ulp_diff(a, -0.0) + ulp_diff(0.0, b)
return abs(float_to_int(a) - float_to_int(b))


def compare_ulp(a: float, b: float, /, ulps: int) -> bool:
"""Check if two numbers match to within a specified number of FP64 ULPs."""
if ulps < 0:
raise ValueError("ulps must be non-negative")
return abs(ulp_diff(a, b)) <= ulps


@dataclass(frozen=True)
class FloatInspector:
float_val: float
int_val: int = field(init=False)

SIGN_MASK = 0x8000000000000000
EXP_MASK = 0x7FF0000000000000
MANT_MASK = 0x000FFFFFFFFFFFFF

def __post_init__(self):
object.__setattr__(self, "int_val", float_to_int(self.float_val))

@cached_property
def raw_sign(self) -> int:
return (self.int_val & self.SIGN_MASK) >> 63

@cached_property
def raw_exponent(self) -> int:
return (self.int_val & self.EXP_MASK) >> 52

@cached_property
def raw_mantissa(self) -> int:
return self.int_val & self.MANT_MASK

def __str__(self) -> str:
return (
f"FloatData({float(self)} = "
f"{self.raw_sign} * 2^{self.exponent} * {self.mantissa}; "
f"s={self.raw_sign}, e={self.raw_exponent:011b}, m={self.raw_mantissa:052b})"
)

def __repr__(self) -> str:
f = float(self)
return f"FloatData({f}; {f.hex()})"

def __float__(self) -> float:
return self.float_val

@cached_property
def exponent(self) -> int:
return self.raw_exponent - 1023

@cached_property
def mantissa(self) -> float:
frac = self.raw_mantissa / (1 << 52)
if self.raw_exponent == 0:
# zero and subnormals
frac *= 2
else:
frac += 1
if self.raw_sign == 1:
frac *= -1
return frac

def is_negative(self) -> bool:
return bool(self.raw_sign)

def is_inf(self) -> bool:
return self.raw_exponent == 0x7FF and self.raw_mantissa == 0

def is_nan(self) -> bool:
return self.raw_exponent == 0x7FF and self.raw_mantissa != 0

def is_subnormal(self) -> bool:
return self.raw_exponent == 0 and self.raw_mantissa != 0
117 changes: 117 additions & 0 deletions tests/test_fp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# SPDX-FileCopyrightText: 2023-present Eric T. Johnson
#
# SPDX-License-Identifier: BSD-3-Clause

import math
import sys

import pytest
from hypothesis import assume, example, given, note
from hypothesis import strategies as st

from yut23_utils.fp import (
FloatInspector,
compare_ulp,
float_to_int,
int_to_float,
ulp_diff,
)


@given(st.floats())
def test_float_to_int(x_flt: float) -> None:
x_int = float_to_int(x_flt)
assert isinstance(x_int, int)
assert 0 <= x_int < 1 << 64
x_flt_2 = int_to_float(x_int)
if math.isnan(x_flt):
assert math.isnan(x_flt_2)
else:
assert x_flt_2 == x_flt


@given(st.integers(min_value=0, max_value=(1 << 64) - 1))
def test_int_to_float(x_int: int) -> None:
x_flt = int_to_float(x_int)
assert float_to_int(x_flt) == x_int


@st.composite
def float_pairs(draw: st.DrawFn, max_ulps: int) -> tuple[float, float, int]:
a = draw(st.floats(allow_nan=False, allow_infinity=False))
ulps = draw(st.integers(min_value=-max_ulps, max_value=max_ulps))

b = a
direction = math.copysign(math.inf, ulps)
for _ in range(abs(ulps)):
b = math.nextafter(b, direction)
assume(not math.isinf(b))

note(f"a = {FloatInspector(a)}\nb = {FloatInspector(b)}")
return a, b, ulps


@given(float_pairs(max_ulps=100))
@example((0.0, -5e-324, -1))
@example((5e-324, 0.0, -1))
@example((5e-324, -0.0, -1))
@example((-5e-324, 5e-324, 2))
def test_ulp_diff(args: tuple[float, float, int]) -> None:
a, b, ulps = args
assert ulp_diff(a, b) == ulps


@pytest.mark.parametrize(
"a,b",
[
(math.inf, 3.14),
(math.inf, math.inf),
(-math.inf, math.inf),
(math.nan, -2.718),
(math.nan, math.nan),
],
)
def test_ulp_diff_errors(a: float, b: float) -> None:
with pytest.raises(ValueError):
ulp_diff(a, b)
with pytest.raises(ValueError):
ulp_diff(b, a)


@given(float_pairs(max_ulps=100))
@example((-5e-324, 5e-324, 2))
def test_compare_ulp(args: tuple[float, float, int]) -> None:
a, b, ulps = args
assert compare_ulp(a, b, abs(ulps))
if abs(ulps) > 0:
assert not compare_ulp(a, b, abs(ulps) - 1)
assert compare_ulp(a, b, abs(ulps) + 1)


def test_compare_ulp_errors() -> None:
with pytest.raises(ValueError):
compare_ulp(1.0, 1.0, -1)


@given(st.floats())
def test_FloatInspector(x: float) -> None:
fi = FloatInspector(x)
note(f"fi={fi} (subnormal: {fi.is_subnormal()})")
if not math.isnan(x):
assert float(fi) == x
# convert to an int to check raw NaN representation
assert float_to_int(float(fi)) == float_to_int(x)

# try manually reconstructing the value from the mantissa and exponent
if math.isfinite(x):
reconstructed = math.ldexp(fi.mantissa, fi.exponent)
assert reconstructed == x

# check interrogator methods
# need to use copysign rather than `x < 0.0` to handle -0.0 properly
is_negative = math.copysign(1.0, x) == -1.0
assert fi.is_negative() == is_negative
assert fi.is_inf() == math.isinf(x)
assert fi.is_nan() == math.isnan(x)
is_subnormal = 0 < abs(x) < sys.float_info.min
assert fi.is_subnormal() == is_subnormal

0 comments on commit 48d29d9

Please sign in to comment.