diff --git a/wolfssl-sys/build.rs b/wolfssl-sys/build.rs index 393fb9b..55cb4a3 100644 --- a/wolfssl-sys/build.rs +++ b/wolfssl-sys/build.rs @@ -50,6 +50,7 @@ const PATCHES: &[&str] = &[ "fix-kyber-mlkem-benchmark.patch", "fix-mlkem-get-curve-name.patch", "fix-kyber-get-curve-name.patch", + "fix-kyber-prf-non-avx2.patch", ]; /** diff --git a/wolfssl-sys/patches/fix-kyber-prf-non-avx2.patch b/wolfssl-sys/patches/fix-kyber-prf-non-avx2.patch new file mode 100644 index 0000000..2588c18 --- /dev/null +++ b/wolfssl-sys/patches/fix-kyber-prf-non-avx2.patch @@ -0,0 +1,51 @@ +From e507c466d5b97f1b1b063783f4ae020c38ebae4b Mon Sep 17 00:00:00 2001 +From: Sean Parkinson +Date: Fri, 20 Dec 2024 11:03:58 +1000 +Subject: [PATCH] ML-KEM/Kyber: fix kyber_prf() for when no AVX2 + +When no AVX2 available, kyber_prf() is called to produce more than one +SHAKE-256 blocks worth of ouput. Otherwise only one block is needed. +Changed function to support an outlen of greater than one block. +--- + wolfcrypt/src/wc_kyber_poly.c | 27 +++++++++++++++++---------- + 1 file changed, 17 insertions(+), 10 deletions(-) + +diff --git a/wolfcrypt/src/wc_kyber_poly.c b/wolfcrypt/src/wc_kyber_poly.c +index d947d37e95..76b5cd5d77 100644 +--- a/wolfcrypt/src/wc_kyber_poly.c ++++ b/wolfcrypt/src/wc_kyber_poly.c +@@ -2074,17 +2074,24 @@ static int kyber_prf(wc_Shake* shake256, byte* out, unsigned int outLen, + (25 - KYBER_SYM_SZ / 8 - 1) * sizeof(word64)); + state[WC_SHA3_256_COUNT - 1] = W64LIT(0x8000000000000000); + +- if (IS_INTEL_BMI2(cpuid_flags)) { +- sha3_block_bmi2(state); +- } +- else if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) { +- sha3_block_avx2(state); +- RESTORE_VECTOR_REGISTERS(); +- } +- else { +- BlockSha3(state); ++ while (outLen > 0) { ++ unsigned int len = min(outLen, WC_SHA3_256_BLOCK_SIZE); ++ ++ if (IS_INTEL_BMI2(cpuid_flags)) { ++ sha3_block_bmi2(state); ++ } ++ else if (IS_INTEL_AVX2(cpuid_flags) && ++ (SAVE_VECTOR_REGISTERS2() == 0)) { ++ sha3_block_avx2(state); ++ RESTORE_VECTOR_REGISTERS(); ++ } ++ else { ++ BlockSha3(state); ++ } ++ XMEMCPY(out, state, len); ++ out += len; ++ outLen -= len; + } +- XMEMCPY(out, state, outLen); + + return 0; + #else