Skip to content

Commit

Permalink
update to add default seed api
Browse files Browse the repository at this point in the history
todo: update docs
  • Loading branch information
supinie committed Jan 8, 2025
1 parent 9e1eeab commit 7f5ec65
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 51 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ license = "GPL-3.0-or-later"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[features]
default = []
decap_key = [] # Use the true key instead of seed for PrivateKey. Default uses seed.

[profile.release]
opt-level = "s"
lto = false
Expand Down
11 changes: 9 additions & 2 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub enum CrystalsError {
InvalidSeedLength(usize, usize),
InternalError(),
InvalidK(usize),
InvalidCiphertextLength(usize, usize, K),
InvalidCiphertextLength(usize),
}

impl Display for CrystalsError {
Expand All @@ -24,7 +24,7 @@ impl Display for CrystalsError {
Self::InvalidSeedLength(seed_len, expected_seed_len) => write!(f, "Invalid seed length, expected {expected_seed_len}, got {seed_len}"),
Self::InternalError() => write!(f, "Unexpected internal error"),
Self::InvalidK(k) => write!(f, "Recieved invalid k value, {k}, expected 2, 3, or 4"),
Self::InvalidCiphertextLength(ciphertext_len, expected_ciphertext_len, sec_level) => write!(f, "Invalid ciphertext length, expected {expected_ciphertext_len}, got {ciphertext_len} (key security level: {sec_level})"),
Self::InvalidCiphertextLength(ciphertext_len) => write!(f, "Invalid ciphertext length, expected 768, 1088, or 1568, got {ciphertext_len}"),
}
}
}
Expand Down Expand Up @@ -96,6 +96,7 @@ impl From<rand_core::Error> for KeyGenerationError {
#[derive(Debug)]
pub enum EncryptionDecryptionError {
Crystals(CrystalsError),
KeyGenerationError(KeyGenerationError),
TryFromInt(TryFromIntError),
Packing(PackingError),
Rand(rand_core::Error),
Expand All @@ -107,6 +108,12 @@ impl From<CrystalsError> for EncryptionDecryptionError {
}
}

impl From<KeyGenerationError> for EncryptionDecryptionError {
fn from(error: KeyGenerationError) -> Self {
Self::KeyGenerationError(error)
}
}

impl From<TryFromIntError> for EncryptionDecryptionError {
fn from(error: TryFromIntError) -> Self {
Self::TryFromInt(error)
Expand Down
165 changes: 121 additions & 44 deletions src/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,25 @@ use tinyvec::ArrayVec;
/// and is used to [`decapsulate`](PrivateKey::decapsulate) a shared secret from a given ciphertext.
///
/// Can be accessed in byte form by packing into a `u8` array using the [`pack`](PrivateKey::pack) method,
/// and made available for use again using the [`unpack`](PrivateKey::unpack) method. The
/// and made available for use again using the [`unpack`](PrivateKey::unpack) method. If using
/// "decap_key" feature, the
/// array used to pack must be of the correct length for the given security level, see
/// [`pack`](PrivateKey::pack) for more.
#[derive(Debug, Eq, PartialEq)]
pub struct PrivateKey {
#[cfg(not(feature = "decap_key"))]
key: PrivateSeed,
#[cfg(feature = "decap_key")]
key: PrivateKeyInner,
}

#[derive(Debug, Eq, PartialEq)]
struct PrivateSeed {
seed: [u8; 2 * SYMBYTES],
}

#[derive(Debug, Eq, PartialEq)]
struct PrivateKeyInner {
sk: IndcpaPrivateKey,
pk: IndcpaPublicKey,
h_pk: [u8; SYMBYTES],
Expand Down Expand Up @@ -100,7 +114,7 @@ fn shake256_from(input: &[u8]) -> [u8; SHAREDSECRETBYTES] {
fn new_key_from_seed(
seed: &[u8],
sec_level: SecurityLevel,
) -> Result<(PublicKey, PrivateKey), KeyGenerationError> {
) -> Result<(PublicKey, PrivateKeyInner), KeyGenerationError> {
if seed.len() != 2 * SYMBYTES {
return Err(CrystalsError::InvalidSeedLength(seed.len(), 2 * SYMBYTES).into());
}
Expand All @@ -114,7 +128,7 @@ fn new_key_from_seed(

let h_pk: [u8; SYMBYTES] = sha3_256_from(&packed_pk[..sec_level.indcpa_public_key_bytes()]);

Ok((PublicKey { pk, h_pk }, PrivateKey { sk, pk, h_pk, z }))
Ok((PublicKey { pk, h_pk }, PrivateKeyInner { sk, pk, h_pk, z }))
}

/// Acceptable RNG to be used in encapsulation and key generation must have the
Expand All @@ -137,7 +151,17 @@ pub(crate) fn generate_key_pair(

let sec_level = SecurityLevel::new(k);

new_key_from_seed(&seed, sec_level)
let (pk, sk_inner) = new_key_from_seed(&seed, sec_level)?;

Ok((
pk,
PrivateKey {
#[cfg(not(feature = "decap_key"))]
key: PrivateSeed { seed },
#[cfg(feature = "decap_key")]
key: sk_inner,
},
))
}

/// Generates a new keypair for the 512 Security Parameters.
Expand Down Expand Up @@ -237,27 +261,37 @@ pub fn generate_keypair_1024(
}

impl PrivateKey {
#[cfg(feature = "decap_key")]
pub(crate) const fn sec_level(&self) -> SecurityLevel {
self.sk.sec_level()
self.key.sk.sec_level()
}

/// Returns the corresponding public key for a given private key
///
/// # Example
/// ```
/// # use enc_rust::kem::*;
/// let (_, sk) = generate_keypair_768(None)?;
/// let pk = sk.get_public_key();
///
/// # Ok::<(), enc_rust::errors::KeyGenerationError>(())
/// ```
#[must_use]
pub const fn get_public_key(&self) -> PublicKey {
PublicKey {
pk: self.pk,
h_pk: self.h_pk,
}
}
// /// Returns the corresponding public key for a given private key
// ///
// /// # Example
// /// ```
// /// # use enc_rust::kem::*;
// /// let (_, sk) = generate_keypair_768(None)?;
// /// let pk = sk.get_public_key();
// ///
// /// # Ok::<(), enc_rust::errors::KeyGenerationError>(())
// /// ```
// #[must_use]
// pub const fn get_public_key(&self) -> PublicKey {
// #[cfg(not(feature = "decap_key"))]
// {
// let (pk, _) = new_key_from_seed(&self.key.seed, self.key.sec_level).unwrap();

// pk
// }
// #[cfg(feature = "decap_key")]
// {
// PublicKey {
// pk: self.key.pk,
// h_pk: self.key.h_pk,
// }
// }
// }

/// Packs private key into a given buffer
///
Expand All @@ -278,11 +312,21 @@ impl PrivateKey {
/// ```
/// # use enc_rust::kem::*;
/// let (_, sk) = generate_keypair_768(None).unwrap();
/// let mut sk_bytes = [0u8; 2400];
/// sk.pack(&mut sk_bytes)?;
/// #[cfg(feature = "decap_key")]
/// {
/// let mut sk_bytes = [0u8; 2400];
/// sk.pack(&mut sk_bytes)?;
/// }
/// #[cfg(not(feature = "decap_key"))]
/// let sk_bytes = sk.pack();
///
/// # Ok::<(), enc_rust::errors::PackingError>(())
/// ```
#[cfg(not(feature = "decap_key"))]
pub fn pack(&self) -> [u8; 2 * SYMBYTES] {
self.key.seed.clone()
}
#[cfg(feature = "decap_key")]
pub fn pack(&self, bytes: &mut [u8]) -> Result<(), PackingError> {
let sec_level = self.sec_level();

Expand All @@ -297,10 +341,10 @@ impl PrivateKey {
let (sk_bytes, rest) = bytes.split_at_mut(sec_level.indcpa_private_key_bytes());
let (pk_bytes, rest) = rest.split_at_mut(sec_level.indcpa_public_key_bytes());
let (h_pk_bytes, z_bytes) = rest.split_at_mut(SYMBYTES);
self.sk.pack(sk_bytes)?;
self.pk.pack(pk_bytes)?;
h_pk_bytes.copy_from_slice(&self.h_pk);
z_bytes.copy_from_slice(&self.z);
self.key.sk.pack(sk_bytes)?;
self.key.pk.pack(pk_bytes)?;
h_pk_bytes.copy_from_slice(&self.key.h_pk);
z_bytes.copy_from_slice(&self.key.z);

Ok(())
}
Expand All @@ -320,12 +364,29 @@ impl PrivateKey {
/// ```
/// # use enc_rust::kem::*;
/// # let (pk, new_sk) = generate_keypair_768(None).unwrap();
/// # let mut sk_bytes = [0u8; 2400];
/// # new_sk.pack(&mut sk_bytes)?;
/// # #[cfg(feature = "decap_key")]
/// # {
/// # let mut sk_bytes = [0u8; 2400];
/// # new_sk.pack(&mut sk_bytes)?;
/// # }
/// # #[cfg(not(feature = "decap_key"))]
/// # let sk_bytes = new_sk.pack();
///
/// #[cfg(feature = "decap_key")]
/// let sk = PrivateKey::unpack(&sk_bytes)?;
///
/// #[cfg(not(feature = "decap_key"))]
/// let sk = PrivateKey::unpack(sk_bytes);
///
/// # Ok::<(), enc_rust::errors::PackingError>(())
/// ```
#[cfg(not(feature = "decap_key"))]
pub fn unpack(bytes: [u8; 2 * SYMBYTES]) -> Self {
Self {
key: PrivateSeed { seed: bytes },
}
}
#[cfg(feature = "decap_key")]
pub fn unpack(bytes: &[u8]) -> Result<Self, PackingError> {
let sec_level = match bytes.len() {
1632 => SecurityLevel::new(K::Two),
Expand All @@ -344,7 +405,9 @@ impl PrivateKey {
let mut z = [0u8; SYMBYTES];
z.copy_from_slice(z_bytes);

Ok(Self { sk, pk, h_pk, z })
Ok(Self {
key: PrivateKeyInner { sk, pk, h_pk, z },
})
}

/// Decapsulates a ciphertext (given as a byte slice) into the shared secret
Expand Down Expand Up @@ -373,25 +436,39 @@ impl PrivateKey {
&self,
ciphertext: &[u8],
) -> Result<[u8; SHAREDSECRETBYTES], EncryptionDecryptionError> {
let sec_level = self.sec_level();
let valid_bytes = [
SecurityLevel::new(K::Two).ciphertext_bytes(),
SecurityLevel::new(K::Three).ciphertext_bytes(),
SecurityLevel::new(K::Four).ciphertext_bytes(),
];

let sec_level = match ciphertext.len() {
len if len == valid_bytes[0] => {
Ok::<SecurityLevel, CrystalsError>(SecurityLevel::new(K::Two))
}
len if len == valid_bytes[1] => {
Ok::<SecurityLevel, CrystalsError>(SecurityLevel::new(K::Three))
}
len if len == valid_bytes[2] => {
Ok::<SecurityLevel, CrystalsError>(SecurityLevel::new(K::Four))
}
_ => Err(CrystalsError::InvalidCiphertextLength(ciphertext.len()).into()),
}?;

if ciphertext.len() != sec_level.ciphertext_bytes() {
return Err(CrystalsError::InvalidCiphertextLength(
ciphertext.len(),
sec_level.ciphertext_bytes(),
sec_level.k(),
)
.into());
}
#[cfg(not(feature = "decap_key"))]
let (_, inner) = new_key_from_seed(&self.key.seed, sec_level)?;
#[cfg(feature = "decap_key")]
let inner = &self.key;

let m = self.sk.decrypt(ciphertext)?;
let m = inner.sk.decrypt(ciphertext)?;

let (k, r) = sha3_512_from(&[m, self.h_pk].concat());
let (k, r) = sha3_512_from(&[m, inner.h_pk].concat());

let k_bar = shake256_from(&[&self.z, ciphertext].concat());
let k_bar = shake256_from(&[&inner.z, ciphertext].concat());

let mut ct = [0u8; MAX_CIPHERTEXT]; // max indcpa_bytes()
self.pk
inner
.pk
.encrypt(&m, &r, &mut ct[..sec_level.indcpa_bytes()])?;

let equal = ct.ct_eq(ciphertext);
Expand Down
19 changes: 14 additions & 5 deletions src/tests/kem.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,23 @@ mod kem_tests {
let mut pk_bytes = [0u8; 1568];
pk.pack(&mut pk_bytes[..pk.sec_level().public_key_bytes()]);
let unpacked_pk = PublicKey::unpack(&pk_bytes[..pk.sec_level().public_key_bytes()]).unwrap();
assert_eq!(pk, unpacked_pk);

#[cfg(feature = "decap_key")]
{
let mut sk_bytes = [0u8; 3168];
sk.pack(&mut sk_bytes[..pk.sec_level().private_key_bytes()]);
let unpacked_sk = PrivateKey::unpack(&sk_bytes[..sk.sec_level().private_key_bytes()]).unwrap();
assert_eq!(sk, unpacked_sk);
}
#[cfg(not(feature = "decap_key"))]
{
let sk_bytes = sk.pack();
let unpacked_sk = PrivateKey::unpack(sk_bytes);

let mut sk_bytes = [0u8; 3168];
sk.pack(&mut sk_bytes[..pk.sec_level().private_key_bytes()]);
let unpacked_sk = PrivateKey::unpack(&sk_bytes[..sk.sec_level().private_key_bytes()]).unwrap();
assert_eq!(sk, unpacked_sk);
}

assert_eq!(pk, unpacked_pk);
assert_eq!(sk, unpacked_sk);
}
}
}

0 comments on commit 7f5ec65

Please sign in to comment.