diff --git a/hppk.go b/hppk.go index 91fd319..7d334b6 100644 --- a/hppk.go +++ b/hppk.go @@ -29,6 +29,10 @@ const ( ERR_MSG_INVALID_PRIME = "invalid prime number" ) +const ( + MULTIVARIATE = 5 +) + // PrivateKey represents a private key in the HPPK protocol. type PrivateKey struct { Prime *big.Int // Prime number used for cryptographic operations @@ -41,8 +45,8 @@ type PrivateKey struct { // PublicKey represents a public key in the HPPK protocol. type PublicKey struct { - P []*big.Int // Coefficients of the polynomial P(x) - Q []*big.Int // Coefficients of the polynomial Q(x) + P [][]*big.Int // Coefficient matrix of the polynomial P(x) + Q [][]*big.Int } // Signature represents a digital signature in the HPPK protocol. @@ -115,51 +119,60 @@ RETRY: goto RETRY } - // Generate random coefficients for the polynomial Bn(x) - Bn := make([]*big.Int, order) - for i := 0; i < len(Bn); i++ { - r, err := rand.Int(rand.Reader, prime) - if err != nil { - return nil, err - } - Bn[i] = r - } - Bn = append(Bn, big.NewInt(1)) - // Initialize P and Q with zero values - P := make([]*big.Int, len(Bn)+1) - Q := make([]*big.Int, len(Bn)+1) - for i := 0; i < len(P); i++ { - P[i] = big.NewInt(0) - Q[i] = big.NewInt(0) - } - - t := new(big.Int) - // Multiply f(x) and h(x) with Bn to get P and Q - for i := 0; i < len(Bn); i++ { - // Vector P - t.Mul(f0, Bn[i]) - P[i].Add(P[i], t) - P[i].Mod(P[i], prime) + Pm := make([][]*big.Int, MULTIVARIATE) + Qm := make([][]*big.Int, MULTIVARIATE) + + for c := 0; c < MULTIVARIATE; c++ { + // Generate random coefficients for the polynomial Bn(x) + Bn := make([]*big.Int, order) + for i := 0; i < len(Bn); i++ { + r, err := rand.Int(rand.Reader, prime) + if err != nil { + return nil, err + } + Bn[i] = r + } + Bn = append(Bn, big.NewInt(1)) - t.Mul(f1, Bn[i]) - P[i+1].Add(P[i+1], t) - P[i+1].Mod(P[i+1], prime) + P := make([]*big.Int, len(Bn)+1) + Q := make([]*big.Int, len(Bn)+1) + for i := 0; i < len(P); i++ { + P[i] = big.NewInt(0) + Q[i] = big.NewInt(0) + } - // Vector Q - t.Mul(h0, Bn[i]) - Q[i].Add(Q[i], t) - Q[i].Mod(Q[i], prime) + t := new(big.Int) + // Multiply f(x) and h(x) with Bn to get P and Q + for i := 0; i < len(Bn); i++ { + // Vector P + t.Mul(f0, Bn[i]) + P[i].Add(P[i], t) + P[i].Mod(P[i], prime) + + t.Mul(f1, Bn[i]) + P[i+1].Add(P[i+1], t) + P[i+1].Mod(P[i+1], prime) + + // Vector Q + t.Mul(h0, Bn[i]) + Q[i].Add(Q[i], t) + Q[i].Mod(Q[i], prime) + + t.Mul(h1, Bn[i]) + Q[i+1].Add(Q[i+1], t) + Q[i+1].Mod(Q[i+1], prime) + } - t.Mul(h1, Bn[i]) - Q[i+1].Add(Q[i+1], t) - Q[i+1].Mod(Q[i+1], prime) - } + // Convert P, Q to Ring S + for i := 0; i < len(P); i++ { + ring(r1, s1, P[i]) + ring(r2, s2, Q[i]) + } - // Convert P, Q to Ring S - for i := 0; i < len(P); i++ { - ring(r1, s1, P[i]) - ring(r2, s2, Q[i]) + // matrix + Pm[c] = P + Qm[c] = Q } // Return the generated private key @@ -174,8 +187,8 @@ RETRY: H0: h0, H1: h1, PublicKey: PublicKey{ - P: P, - Q: Q, + P: Pm, + Q: Qm, }, }, nil } @@ -222,30 +235,31 @@ func encrypt(pub *PublicKey, msg []byte, prime *big.Int) (kem *KEM, err error) { } } - // Generate a random noise - noise, err := rand.Int(rand.Reader, prime) - if err != nil { - return nil, err - } - noise = noise.Exp(big.NewInt(2), big.NewInt(1800), nil) - - // Initialize Si with the secret message - Si := big.NewInt(1) // Compute the encrypted values P and Q P := new(big.Int) Q := new(big.Int) - t := new(big.Int) - for i := 0; i < len(pub.P); i++ { - noised := new(big.Int).Mul(noise, Si) - noised.Mod(noised, prime) + for c := 0; c < len(pub.P); c++ { + // Generate a random noise + noise, err := rand.Int(rand.Reader, prime) + if err != nil { + return nil, err + } - P.Add(P, t.Mul(noised, pub.P[i])) - Q.Add(Q, t.Mul(noised, pub.Q[i])) + // Initialize Si with the secret message + Si := big.NewInt(1) + t := new(big.Int) + for i := 0; i < len(pub.P[0]); i++ { + noised := new(big.Int).Mul(noise, Si) + noised.Mod(noised, prime) - // Si = secret^i - Si.Mul(Si, secret) - Si.Mod(Si, prime) + P.Add(P, t.Mul(Si, pub.P[c][i])) + Q.Add(Q, t.Mul(Si, pub.Q[c][i])) + + // Si = secret^i + Si.Mul(Si, secret) + Si.Mod(Si, prime) + } } return &KEM{P: P, Q: Q}, nil @@ -370,10 +384,6 @@ func (priv *PrivateKey) Sign(digest []byte) (sign *Signature, err error) { S2Pub := new(big.Int).Mul(beta, priv.S2) S2Pub.Mod(S2Pub, prime) - // Initiate V, U - V := make([]*big.Int, len(priv.P)) - U := make([]*big.Int, len(priv.Q)) - // make K >= L+ 32 K := priv.S1.BitLen() if priv.S2.BitLen() > K { @@ -382,11 +392,14 @@ func (priv *PrivateKey) Sign(digest []byte) (sign *Signature, err error) { K += 32 R := new(big.Int).Exp(big.NewInt(2), big.NewInt(int64(K)), nil) + // Initiate V, U + V := make([]*big.Int, len(priv.P[0])) + U := make([]*big.Int, len(priv.Q[0])) for i := 0; i < len(V); i++ { - V[i] = new(big.Int).Mul(priv.Q[i], R) + V[i] = new(big.Int).Mul(priv.Q[0][i], R) V[i].Quo(V[i], priv.S2) - U[i] = new(big.Int).Mul(priv.P[i], R) + U[i] = new(big.Int).Mul(priv.P[0][i], R) U[i].Quo(U[i], priv.S1) } @@ -410,7 +423,7 @@ func (priv *PrivateKey) Public() *PublicKey { // Order returns the polynomial order of the private key. func (priv *PrivateKey) Order() int { - return len(priv.PublicKey.P) - 2 + return len(priv.PublicKey.P[0]) - 2 } // VerifySignature verifies the signature of the message digest using the public key and given prime @@ -456,10 +469,10 @@ func verifySignature(sig *Signature, digest []byte, pub *PublicKey, prime *big.I Q := make([]*big.Int, len(sig.U)) P := make([]*big.Int, len(sig.V)) for i := 0; i < len(Q); i++ { - Q[i] = new(big.Int).Mul(pub.Q[i], sig.Beta) + Q[i] = new(big.Int).Mul(pub.Q[0][i], sig.Beta) Q[i].Mod(Q[i], prime) - P[i] = new(big.Int).Mul(pub.P[i], sig.Beta) + P[i] = new(big.Int).Mul(pub.P[0][i], sig.Beta) P[i].Mod(P[i], prime) } @@ -508,7 +521,7 @@ func verifySignature(sig *Signature, digest []byte, pub *PublicKey, prime *big.I func createCoPrimePair(polyTerms int, p *big.Int) (R *big.Int, S *big.Int, err error) { one := big.NewInt(1) - bitLength := 2*p.BitLen() + big.NewInt(int64(polyTerms)).BitLen() + bitLength := 2*p.BitLen() + big.NewInt(int64(polyTerms)*MULTIVARIATE).BitLen() L := big.NewInt(1) L.Lsh(L, uint(bitLength))