Skip to content

Commit

Permalink
Merge pull request #3 from supinie/proptest
Browse files Browse the repository at this point in the history
added Proptest
  • Loading branch information
supinie authored Feb 8, 2024
2 parents aaaa6a3 + 29643f5 commit 554dd25
Show file tree
Hide file tree
Showing 18 changed files with 700 additions and 287 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,6 @@ Cargo.lock


*.swp

# Proptest
proptest-regressions
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,15 @@ byteorder = "1.4.3"
# clap = { version = "4.3.12", features = ["cargo"] }
more-asserts = "0.3.1"
num_enum = { version = "0.7.1", default-features = false }
rand_core = { version = "0.6.4", default-features = false }
sha3 = "0.10.8"
tinyvec = "1.6.0"
zeroize = { version = "1.7.0", default-features = false }

[dev-dependencies]
rand = "0.8.5"
proptest = "1.4.0"


# [workspace]
# members = ["src/kyber"]
Expand Down
23 changes: 11 additions & 12 deletions src/field_operations.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use more_asserts::assert_ge;

use crate::params::Q;

// given -2^15 q <= x < 2^15 q, returns -q < y < q with y = x 2^-16 mod q
// Example:
// let x = montgomery_reduce(y);
// ```
// let x = montgomery_reduce(5); //TODO: Tris broke this to remind you to make doctests!
// ```
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
pub fn montgomery_reduce(x: i32) -> i16 {
const QPRIME: i32 = 62209;
Expand Down Expand Up @@ -50,14 +50,13 @@ pub fn barrett_reduce(x: i16) -> i16 {
// Example:
// let x = conditional_sub_q(y);
#[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
pub fn conditional_sub_q(x: i16) -> i16 {
pub const fn conditional_sub_q(x: i16) -> i16 {
const Q_16: i16 = Q as i16;
assert_ge!(
x,
-29439,
"x must be >= to -29439 when applying conditional subtract q"
);
let mut result = x - Q_16;
result += (result >> 15) & Q_16;
result
if x < Q_16 {
x
} else {
let mut result = x - Q_16;
result += (result >> 15) & Q_16;
result
}
}
35 changes: 18 additions & 17 deletions src/indcpa.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use core::num::TryFromIntError;

use crate::{
matrix::{MatOperations, New},
params::{GetSecLevel, POLYBYTES},
params::{GetSecLevel, POLYBYTES, SYMBYTES},
polynomials::Poly,
vectors::{LinkSecLevel, PolyVecOperations},
};
Expand All @@ -15,7 +15,7 @@ pub struct PrivateKey<PV: PolyVecOperations> {

#[derive(Default, PartialEq, Debug, Eq)]
pub struct PublicKey<PV: PolyVecOperations, M: MatOperations + LinkSecLevel<PV>> {
pub rho: [u8; 32],
pub rho: [u8; SYMBYTES],
pub noise: PV,
pub a_t: M,
}
Expand Down Expand Up @@ -55,15 +55,15 @@ where
M: MatOperations + GetSecLevel + LinkSecLevel<PV> + New + IntoIterator<Item = PV> + Copy,
{
let mut pub_key = PublicKey {
rho: [0u8; 32],
rho: [0u8; SYMBYTES],
noise: PV::new_filled(),
a_t: M::new(),
};
let mut priv_key = PrivateKey {
secret: PV::new_filled(),
};

let mut expanded_seed = [0u8; 64];
let mut expanded_seed = [0u8; 2 * SYMBYTES];
let mut hash = Sha3_512::new();
hash.update(seed);

Expand Down Expand Up @@ -104,16 +104,19 @@ pub fn encrypt<PV, M>(
seed: &[u8],
// output_buf: &'a mut [u8],
output_buf: &mut [u8],
// ) -> Result<&'a [u8], TryFromIntError>
// ) -> Result<&'a [u8], TryFromIntError>
) -> Result<(), TryFromIntError>
where
PV: PolyVecOperations + GetSecLevel + Default + IntoIterator<Item = Poly> + Copy,
M: MatOperations + GetSecLevel + LinkSecLevel<PV> + New + IntoIterator<Item = PV> + Copy,
{
let mut m = Poly::new();
m.read_msg(plaintext)?;

let mut rh = PV::new_filled();
rh.derive_noise(seed, 0, PV::sec_level().eta_1());
rh.ntt();
rh.barrett_reduce();
// rh.barrett_reduce();

let k_value: u8 = PV::sec_level().k().into();
let mut error_1 = PV::new_filled();
Expand All @@ -125,27 +128,25 @@ where
for (mut poly, vec) in u.into_iter().zip(pub_key.a_t) {
poly.inner_product_pointwise(vec, rh);
}
u.barrett_reduce();
u.inv_ntt();

u.add(error_1);
u.barrett_reduce();

let mut v = Poly::new();
v.inner_product_pointwise(pub_key.noise, rh);
v.barrett_reduce();
// v.barrett_reduce();
v.inv_ntt();

let mut m = Poly::new();
m.read_msg(plaintext)?;

v.add(&m);
v.add(&error_2);

u.normalise();
v.normalise();
v.barrett_reduce();

// u.normalise();
// v.normalise();

let poly_vec_compressed_bytes: usize = PV::sec_level().poly_vec_compressed_bytes();
let poly_compressed_bytes: usize = PV::sec_level().poly_compressed_bytes();
u.compress(&mut output_buf[..poly_vec_compressed_bytes])?;
u.compress(output_buf)?;
v.compress(
&mut output_buf[poly_vec_compressed_bytes..],
&PV::sec_level(),
Expand All @@ -160,7 +161,7 @@ pub fn decrypt<PV>(
ciphertext: &[u8],
// output_buf: &'a mut [u8],
output_buf: &mut [u8],
// ) -> Result<&'a [u8], TryFromIntError>
// ) -> Result<&'a [u8], TryFromIntError>
) -> Result<(), TryFromIntError>
where
PV: PolyVecOperations + GetSecLevel + Default + IntoIterator<Item = Poly> + Copy,
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ mod vectors;
mod tests {
mod buffer;
mod field_operations;
mod indcpa;
mod matrix;
mod ntt;
mod params;
mod polynomials;
mod sample;
mod vectors;
mod indcpa;
}
4 changes: 2 additions & 2 deletions src/ntt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ impl Poly {
l <<= 1;
}

for j in 0..N {
self.coeffs[j] = montgomery_reduce(1441 * i32::from(self.coeffs[j]));
for coeff in &mut self.coeffs {
*coeff = montgomery_reduce(1441 * i32::from(*coeff));
}
}
}
Loading

0 comments on commit 554dd25

Please sign in to comment.