From ea4765cf25258c440c1628110d94857be352f025 Mon Sep 17 00:00:00 2001 From: Ruiyu Zhu Date: Wed, 4 May 2022 10:59:40 -0700 Subject: [PATCH] Fix potential overflow in Intp type (#197) Summary: Pull Request resolved: https://github.com/facebookresearch/fbpcf/pull/197 We use int32_t to store 32 bit signed integers. This might cause undesired behavior. To overcome this issue, we add some extra special treatment for these edge cases. Reviewed By: chualynn Differential Revision: D36106694 fbshipit-source-id: 978a021d46382857f59f43b454adeb4d7bdcbb40 --- fbpcf/mpc_std_lib/util/Intp_impl.h | 41 +++++++++++++++++++- fbpcf/mpc_std_lib/util/test/IntpTest.cpp | 49 ++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 fbpcf/mpc_std_lib/util/test/IntpTest.cpp diff --git a/fbpcf/mpc_std_lib/util/Intp_impl.h b/fbpcf/mpc_std_lib/util/Intp_impl.h index 4001ff21..f8aa4a89 100644 --- a/fbpcf/mpc_std_lib/util/Intp_impl.h +++ b/fbpcf/mpc_std_lib/util/Intp_impl.h @@ -7,6 +7,7 @@ #pragma once #include +#include #include #include #include @@ -54,7 +55,7 @@ class Intp { } Intp operator+(const Intp& other) const { - return Intp(round(v_ + other.v_)); + return Intp(add(v_, other.v_)); } Intp operator-() const { @@ -70,7 +71,7 @@ class Intp { } Intp operator-(const Intp& other) const { - return Intp(round(v_ - other.v_)); + return Intp(subtract(v_, other.v_)); } public: @@ -114,6 +115,42 @@ class Intp { } } + static NativeType add(NativeType a, NativeType b) { + if constexpr ( + isSigned && + ((width == 8) || (width == 16) || (width == 32) || (width == 64))) { + // special handling is needed only for signed integer with some special + // width (e.g. overflow is possible). + if (std::signbit(a) == std::signbit(b)) { + // the two numbers have the same sign, overflow is possible. + // special treatment to prevent overflow + return round(uint64_t(a) + uint64_t(b)); + } else { + return round(a + b); + } + } else { + return round(a + b); + } + } + + static NativeType subtract(NativeType a, NativeType b) { + if constexpr ( + isSigned && + ((width == 8) || (width == 16) || (width == 32) || (width == 64))) { + // special handling is needed only for signed integer with some special + // width (e.g. overflow is possible). + if (std::signbit(a) != std::signbit(b)) { + // the two numbers have different sign, overflow is possible. + // special treatment to prevent overflow + return round(uint64_t(a) - uint64_t(b)); + } else { + return round(a - b); + } + } else { + return round(a - b); + } + } + NativeType v_; }; diff --git a/fbpcf/mpc_std_lib/util/test/IntpTest.cpp b/fbpcf/mpc_std_lib/util/test/IntpTest.cpp new file mode 100644 index 00000000..0ec7034f --- /dev/null +++ b/fbpcf/mpc_std_lib/util/test/IntpTest.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +#include "fbpcf/mpc_std_lib/util/util.h" + +namespace fbpcf::mpc_std_lib::util { + +TEST(IntpTypeTest, testAdd) { + const int8_t width = 32; + int64_t largestSigned = std::numeric_limits().max(); + int64_t smallestSigned = std::numeric_limits().min(); + std::random_device rd; + std::mt19937_64 e(rd()); + std::uniform_int_distribution dist(smallestSigned, largestSigned); + for (int i = 0; i < 1000; i++) { + auto v1 = dist(e); + auto v2 = dist(e); + int32_t v = Intp(v1) + Intp(v2); + int32_t expectedV = (uint64_t)v1 + (uint64_t)v2; + EXPECT_EQ(v, expectedV); + } +} + +TEST(IntpTypeTest, testSubtract) { + const int8_t width = 32; + int64_t largestSigned = std::numeric_limits().max(); + int64_t smallestSigned = std::numeric_limits().min(); + std::random_device rd; + std::mt19937_64 e(rd()); + std::uniform_int_distribution dist(smallestSigned, largestSigned); + for (int i = 0; i < 1000; i++) { + auto v1 = dist(e); + auto v2 = dist(e); + int32_t v = Intp(v1) - Intp(v2); + int32_t expectedV = (uint64_t)v1 - (uint64_t)v2; + EXPECT_EQ(v, expectedV); + } +} + +} // namespace fbpcf::mpc_std_lib::util