diff --git a/src/gfloat/__init__.py b/src/gfloat/__init__.py index 4ca1ed4..b47e504 100644 --- a/src/gfloat/__init__.py +++ b/src/gfloat/__init__.py @@ -9,8 +9,10 @@ ) from .decode import decode_float from .printing import float_pow2str, float_tilde_unless_roundtrip_str -from .round import encode_float, round_float -from .round_ndarray import encode_ndarray, round_ndarray +from .round import round_float +from .encode import encode_float +from .round_ndarray import round_ndarray +from .encode_ndarray import encode_ndarray from .decode_ndarray import decode_ndarray from .types import FloatClass, FloatValue, FormatInfo, RoundMode diff --git a/src/gfloat/block.py b/src/gfloat/block.py index f04208f..14c4f02 100644 --- a/src/gfloat/block.py +++ b/src/gfloat/block.py @@ -10,7 +10,8 @@ import numpy.typing as npt from .decode import decode_float -from .round import RoundMode, encode_float, round_float +from .round import RoundMode, round_float +from .encode import encode_float from .types import FormatInfo diff --git a/src/gfloat/encode.py b/src/gfloat/encode.py new file mode 100644 index 0000000..2b71187 --- /dev/null +++ b/src/gfloat/encode.py @@ -0,0 +1,98 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. + +import math + +import numpy as np + +from .types import FormatInfo + + +def encode_float(fi: FormatInfo, v: float) -> int: + """ + Encode input to the given :py:class:`FormatInfo`. + + Will round toward zero if :paramref:`v` is not in the value set. + Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence. + Encode -0 to 0 if not `fi.has_nz` + + For other roundings and saturations, call :func:`round_float` first. + + Args: + fi (FormatInfo): Describes the target format + v (float): The value to be encoded. + + Returns: + The integer code point + """ + + # Format Constants + k = fi.bits + p = fi.precision + t = p - 1 + + # Encode + if np.isnan(v): + return fi.code_of_nan + + # Overflow/underflow + if v > fi.max: + if fi.has_infs: + return fi.code_of_posinf + if fi.num_nans > 0: + return fi.code_of_nan + return fi.code_of_max + + if v < fi.min: + if fi.has_infs: + return fi.code_of_neginf + if fi.num_nans > 0: + return fi.code_of_nan + return fi.code_of_min + + # Finite values + sign = fi.is_signed and np.signbit(v) + vpos = -v if sign else v + + if fi.has_subnormals and vpos <= fi.smallest_subnormal / 2: + isig = 0 + biased_exp = 0 + else: + sig, exp = np.frexp(vpos) + exp = int(exp) # All calculations in Python ints + + # sig in range [0.5, 1) + sig *= 2 + exp -= 1 + # now sig in range [1, 2) + + biased_exp = exp + fi.expBias + if biased_exp < 1 and fi.has_subnormals: + # subnormal + sig *= 2.0 ** (biased_exp - 1) + biased_exp = 0 + assert vpos == sig * 2 ** (1 - fi.expBias) + else: + if sig > 0: + sig -= 1.0 + + isig = math.floor(sig * 2**t) + + # Zero + if isig == 0 and biased_exp == 0 and fi.has_zero: + if sign and fi.has_nz: + return fi.code_of_negzero + else: + return fi.code_of_zero + + # Nonzero + assert isig < 2**t + assert biased_exp < 2**fi.expBits or fi.is_twos_complement + + # Handle two's complement encoding + if fi.is_twos_complement and sign: + isig = (1 << t) - isig + + # Pack values into a single integer + code = (int(sign) << (k - 1)) | (biased_exp << t) | (isig << 0) + + return code diff --git a/src/gfloat/encode_ndarray.py b/src/gfloat/encode_ndarray.py new file mode 100644 index 0000000..e72309b --- /dev/null +++ b/src/gfloat/encode_ndarray.py @@ -0,0 +1,84 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. + +from .types import FormatInfo +import numpy as np + + +def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray: + """ + Vectorized version of :meth:`encode_float`. + + Encode inputs to the given :py:class:`FormatInfo`. + + Will round toward zero if :paramref:`v` is not in the value set. + Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence. + Encode -0 to 0 if not `fi.has_nz` + + For other roundings and saturations, call :func:`round_ndarray` first. + + Args: + fi (FormatInfo): Describes the target format + v (float array): The value to be encoded. + + Returns: + The integer code point + """ + k = fi.bits + p = fi.precision + t = p - 1 + + sign = np.signbit(v) & fi.is_signed + vpos = np.where(sign, -v, v) + + nan_mask = np.isnan(v) + + code = np.zeros_like(v, dtype=np.uint64) + + if fi.num_nans > 0: + code[nan_mask] = fi.code_of_nan + else: + assert not np.any(nan_mask) + + if fi.has_infs: + code[v > fi.max] = fi.code_of_posinf + code[v < fi.min] = fi.code_of_neginf + else: + code[v > fi.max] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_max + code[v < fi.min] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min + + if fi.has_zero: + if fi.has_nz: + code[v == 0] = np.where(sign[v == 0], fi.code_of_negzero, fi.code_of_zero) + else: + code[v == 0] = fi.code_of_zero + + finite_mask = (code == 0) & (v != 0) + assert not np.any(np.isnan(vpos[finite_mask])) + if np.any(finite_mask): + finite_vpos = vpos[finite_mask] + finite_sign = sign[finite_mask] + + sig, exp = np.frexp(finite_vpos) + + biased_exp = exp.astype(np.int64) + (fi.expBias - 1) + subnormal_mask = (biased_exp < 1) & fi.has_subnormals + + biased_exp_safe = np.where(subnormal_mask, biased_exp, 0) + tsig = np.where(subnormal_mask, np.ldexp(sig, biased_exp_safe), sig * 2 - 1.0) + biased_exp[subnormal_mask] = 0 + + isig = np.floor(np.ldexp(tsig, t)).astype(np.int64) + + zero_mask = fi.has_zero & (isig == 0) & (biased_exp == 0) + if not fi.has_nz: + finite_sign[zero_mask] = False + + # Handle two's complement encoding + if fi.is_twos_complement: + isig[finite_sign] = (1 << t) - isig[finite_sign] + + code[finite_mask] = ( + (finite_sign.astype(int) << (k - 1)) | (biased_exp << t) | (isig << 0) + ) + + return code diff --git a/src/gfloat/round.py b/src/gfloat/round.py index 5a9fc5a..2edaf79 100644 --- a/src/gfloat/round.py +++ b/src/gfloat/round.py @@ -166,94 +166,3 @@ def round_float( result = -result return result - - -def encode_float(fi: FormatInfo, v: float) -> int: - """ - Encode input to the given :py:class:`FormatInfo`. - - Will round toward zero if :paramref:`v` is not in the value set. - Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence. - Encode -0 to 0 if not `fi.has_nz` - - For other roundings and saturations, call :func:`round_float` first. - - Args: - fi (FormatInfo): Describes the target format - v (float): The value to be encoded. - - Returns: - The integer code point - """ - - # Format Constants - k = fi.bits - p = fi.precision - t = p - 1 - - # Encode - if np.isnan(v): - return fi.code_of_nan - - # Overflow/underflow - if v > fi.max: - if fi.has_infs: - return fi.code_of_posinf - if fi.num_nans > 0: - return fi.code_of_nan - return fi.code_of_max - - if v < fi.min: - if fi.has_infs: - return fi.code_of_neginf - if fi.num_nans > 0: - return fi.code_of_nan - return fi.code_of_min - - # Finite values - sign = fi.is_signed and np.signbit(v) - vpos = -v if sign else v - - if fi.has_subnormals and vpos <= fi.smallest_subnormal / 2: - isig = 0 - biased_exp = 0 - else: - sig, exp = np.frexp(vpos) - exp = int(exp) # All calculations in Python ints - - # sig in range [0.5, 1) - sig *= 2 - exp -= 1 - # now sig in range [1, 2) - - biased_exp = exp + fi.expBias - if biased_exp < 1 and fi.has_subnormals: - # subnormal - sig *= 2.0 ** (biased_exp - 1) - biased_exp = 0 - assert vpos == sig * 2 ** (1 - fi.expBias) - else: - if sig > 0: - sig -= 1.0 - - isig = math.floor(sig * 2**t) - - # Zero - if isig == 0 and biased_exp == 0 and fi.has_zero: - if sign and fi.has_nz: - return fi.code_of_negzero - else: - return fi.code_of_zero - - # Nonzero - assert isig < 2**t - assert biased_exp < 2**fi.expBits or fi.is_twos_complement - - # Handle two's complement encoding - if fi.is_twos_complement and sign: - isig = (1 << t) - isig - - # Pack values into a single integer - code = (int(sign) << (k - 1)) | (biased_exp << t) | (isig << 0) - - return code diff --git a/src/gfloat/round_ndarray.py b/src/gfloat/round_ndarray.py index 4936525..e5de6f2 100644 --- a/src/gfloat/round_ndarray.py +++ b/src/gfloat/round_ndarray.py @@ -150,83 +150,3 @@ def round_ndarray( result = np.where(result == 0, 0.0, result) return result - - -def encode_ndarray(fi: FormatInfo, v: np.ndarray) -> np.ndarray: - """ - Vectorized version of :meth:`encode_float`. - - Encode inputs to the given :py:class:`FormatInfo`. - - Will round toward zero if :paramref:`v` is not in the value set. - Will saturate to `Inf`, `NaN`, `fi.max` in order of precedence. - Encode -0 to 0 if not `fi.has_nz` - - For other roundings and saturations, call :func:`round_ndarray` first. - - Args: - fi (FormatInfo): Describes the target format - v (float array): The value to be encoded. - - Returns: - The integer code point - """ - k = fi.bits - p = fi.precision - t = p - 1 - - sign = np.signbit(v) & fi.is_signed - vpos = np.where(sign, -v, v) - - nan_mask = np.isnan(v) - - code = np.zeros_like(v, dtype=np.uint64) - - if fi.num_nans > 0: - code[nan_mask] = fi.code_of_nan - else: - assert not np.any(nan_mask) - - if fi.has_infs: - code[v > fi.max] = fi.code_of_posinf - code[v < fi.min] = fi.code_of_neginf - else: - code[v > fi.max] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_max - code[v < fi.min] = fi.code_of_nan if fi.num_nans > 0 else fi.code_of_min - - if fi.has_zero: - if fi.has_nz: - code[v == 0] = np.where(sign[v == 0], fi.code_of_negzero, fi.code_of_zero) - else: - code[v == 0] = fi.code_of_zero - - finite_mask = (code == 0) & (v != 0) - assert not np.any(np.isnan(vpos[finite_mask])) - if np.any(finite_mask): - finite_vpos = vpos[finite_mask] - finite_sign = sign[finite_mask] - - sig, exp = np.frexp(finite_vpos) - - biased_exp = exp.astype(np.int64) + (fi.expBias - 1) - subnormal_mask = (biased_exp < 1) & fi.has_subnormals - - biased_exp_safe = np.where(subnormal_mask, biased_exp, 0) - tsig = np.where(subnormal_mask, np.ldexp(sig, biased_exp_safe), sig * 2 - 1.0) - biased_exp[subnormal_mask] = 0 - - isig = np.floor(np.ldexp(tsig, t)).astype(np.int64) - - zero_mask = fi.has_zero & (isig == 0) & (biased_exp == 0) - if not fi.has_nz: - finite_sign[zero_mask] = False - - # Handle two's complement encoding - if fi.is_twos_complement: - isig[finite_sign] = (1 << t) - isig[finite_sign] - - code[finite_mask] = ( - (finite_sign.astype(int) << (k - 1)) | (biased_exp << t) | (isig << 0) - ) - - return code