Skip to content

Commit

Permalink
Add-based SR (#44)
Browse files Browse the repository at this point in the history
Previous implementation was based on comparison: x - floor(x) > rand
This changes to be based on add: x + rand >= floor(x) + 1

No behavioural change expected, but bias direction and rounding profiles are inverted
  • Loading branch information
awf authored Jan 13, 2025
1 parent 465ee41 commit 279e4f1
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 93 deletions.
6 changes: 2 additions & 4 deletions docs/source/03-value-tables.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,7 @@
" 0xFF,\n",
" ):\n",
" print(\n",
" str_tablerow(\n",
" fi, decode_float(fi, i), show_b16_info=True, vs_width=8, vs_d=4\n",
" )\n",
" str_tablerow(fi, decode_float(fi, i), show_b16_info=True, vs_width=8, vs_d=4)\n",
" )"
]
},
Expand Down Expand Up @@ -3266,7 +3264,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "gfloat-clean",
"language": "python",
"name": "python3"
},
Expand Down
203 changes: 131 additions & 72 deletions docs/source/05-stochastic-rounding.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = {file = ["requirements.txt"]}
optional-dependencies = {dev = {file = ["requirements-dev.txt"]}}

[tool.black]
line-length = 88
line-length = 90
fast = true

[tool.mypy]
Expand Down
4 changes: 1 addition & 3 deletions src/gfloat/decode_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from .types import FormatInfo


def decode_ndarray(
fi: FormatInfo, codes: np.ndarray, np: ModuleType = np
) -> np.ndarray:
def decode_ndarray(fi: FormatInfo, codes: np.ndarray, np: ModuleType = np) -> np.ndarray:
r"""
Vectorized version of :meth:`decode_float`
Expand Down
12 changes: 6 additions & 6 deletions src/gfloat/round.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def round_float(
case RoundMode.TowardNegative:
should_round_away = sign and delta > 0
case RoundMode.TiesToAway:
should_round_away = delta >= 0.5
should_round_away = delta + 0.5 >= 1.0
case RoundMode.TiesToEven:
should_round_away = delta > 0.5 or (delta == 0.5 and code_is_odd)
case RoundMode.Stochastic:
Expand All @@ -113,20 +113,20 @@ def round_float(
(d - floord > 0.5) or ((d - floord == 0.5) and _isodd(floord))
)

should_round_away = d > srbits
should_round_away = d + srbits >= 2.0**srnumbits
case RoundMode.StochasticOdd:
## RTNE delta to srbits
d = delta * 2.0**srnumbits
floord = np.floor(d).astype(np.int64)
d = floord + (
(d - floord > 0.5) or ((d - floord == 0.5) and ~_isodd(floord))
(d - floord > 0.5) or ((d - floord == 0.5) and not _isodd(floord))
)

should_round_away = d > srbits
should_round_away = d + srbits >= 2.0**srnumbits
case RoundMode.StochasticFast:
should_round_away = delta > (0.5 + srbits) * 2.0**-srnumbits
should_round_away = delta + (0.5 + srbits) * 2.0**-srnumbits >= 1.0
case RoundMode.StochasticFastest:
should_round_away = delta > srbits * 2.0**-srnumbits
should_round_away = delta + srbits * 2.0**-srnumbits >= 1.0

if should_round_away:
# This may increase isignificand to 2**p,
Expand Down
16 changes: 12 additions & 4 deletions src/gfloat/round_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,19 @@ def round_ndarray(
match rnd:
case RoundMode.TowardZero:
should_round_away = np.zeros_like(delta, dtype=bool)

case RoundMode.TowardPositive:
should_round_away = ~is_negative & (delta > 0)

case RoundMode.TowardNegative:
should_round_away = is_negative & (delta > 0)

case RoundMode.TiesToAway:
should_round_away = delta >= 0.5

case RoundMode.TiesToEven:
should_round_away = (delta > 0.5) | ((delta == 0.5) & code_is_odd)

case RoundMode.Stochastic:
assert srbits is not None
## RTNE delta to srbits
Expand All @@ -94,7 +99,8 @@ def round_ndarray(
dd = d - floord
drnd = floord + (dd > 0.5) + ((dd == 0.5) & _isodd(floord))

should_round_away = drnd > srbits
should_round_away = drnd + srbits >= 2.0**srnumbits

case RoundMode.StochasticOdd:
assert srbits is not None
## RTNO delta to srbits
Expand All @@ -103,13 +109,15 @@ def round_ndarray(
dd = d - floord
drnd = floord + (dd > 0.5) + ((dd == 0.5) & ~_isodd(floord))

should_round_away = drnd > srbits
should_round_away = drnd + srbits >= 2.0**srnumbits

case RoundMode.StochasticFast:
assert srbits is not None
should_round_away = delta > (2 * srbits + 1) * 2.0 ** -(1 + srnumbits)
should_round_away = delta + (2 * srbits + 1) * 2.0 ** -(1 + srnumbits) >= 1.0

case RoundMode.StochasticFastest:
assert srbits is not None
should_round_away = delta > srbits * 2.0**-srnumbits
should_round_away = delta + srbits * 2.0**-srnumbits >= 1.0

isignificand = np.where(should_round_away, isignificand + 1, isignificand)

Expand Down
4 changes: 1 addition & 3 deletions test/test_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,7 @@ def test_p3109_k8_specials(fi: FormatInfo) -> None:
assert fi.code_of_neginf == 0xFF


@pytest.mark.parametrize(
"k,p", [(8, 3), (8, 1), (6, 1), (6, 5), (3, 1), (3, 2), (11, 3)]
)
@pytest.mark.parametrize("k,p", [(8, 3), (8, 1), (6, 1), (6, 5), (3, 1), (3, 2), (11, 3)])
def test_p3109_specials(k: int, p: int) -> None:
fi = format_info_p3109(k, p)
assert fi.code_of_nan == 2 ** (k - 1)
Expand Down

0 comments on commit 279e4f1

Please sign in to comment.