diff --git a/cardano-crypto-class/cardano-crypto-class.cabal b/cardano-crypto-class/cardano-crypto-class.cabal index 27653d86c..e134f4eb0 100644 --- a/cardano-crypto-class/cardano-crypto-class.cabal +++ b/cardano-crypto-class/cardano-crypto-class.cabal @@ -132,19 +132,27 @@ library pkgconfig-depends: libsecp256k1 cpp-options: -DSECP256K1_ENABLED -test-suite test-memory-example +-- test-suite test-memory-example +-- import: base, project-config +-- -- Temporarily removing this as it is breaking the CI, and +-- -- we don't see the benefit. Will circle back to this to decide +-- -- whether to modify or completely remove. +-- buildable: False +-- type: exitcode-stdio-1.0 +-- hs-source-dirs: memory-example +-- main-is: Main.hs +-- build-depends: +-- , base +-- , bytestring +-- , cardano-crypto-class + +-- if (os(linux) || os(osx)) +-- build-depends: unix + +executable run-msm import: base, project-config - -- Temporarily removing this as it is breaking the CI, and - -- we don't see the benefit. Will circle back to this to decide - -- whether to modify or completely remove. - buildable: False - type: exitcode-stdio-1.0 - hs-source-dirs: memory-example + hs-source-dirs: exe main-is: Main.hs build-depends: , base - , bytestring , cardano-crypto-class - - if (os(linux) || os(osx)) - build-depends: unix diff --git a/cardano-crypto-class/exe/Main.hs b/cardano-crypto-class/exe/Main.hs new file mode 100644 index 000000000..7fe072171 --- /dev/null +++ b/cardano-crypto-class/exe/Main.hs @@ -0,0 +1,21 @@ +{-# LANGUAGE TypeApplications #-} + +module Main where + +import Cardano.Crypto.EllipticCurve.BLS12_381.Internal +import System.IO.Unsafe (unsafePerformIO) +import qualified Data.List.NonEmpty as NonEmpty + +main :: IO () +main = do + let g1 = blsGenerator @Curve1 + pointsCurve1 = [g1,g1,g1,g1] + scalars = map (unsafePerformIO . scalarFromInteger) [0,1,2,3] + poinsAndScalars = NonEmpty.fromList $ zip pointsCurve1 scalars + res1 = blsMSM poinsAndScalars + print $ blsCompress res1 + let res2 = blsMult g1 6 + print $ blsCompress res2 + if res1 == res2 + then putStrLn "Success" + else putStrLn "Failure" \ No newline at end of file diff --git a/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381.hs b/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381.hs index 7f4e83df0..cd9121bec 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381.hs @@ -23,6 +23,7 @@ module Cardano.Crypto.EllipticCurve.BLS12_381 ( blsMult, blsCneg, blsNeg, + blsMSM, blsCompress, blsSerialize, blsUncompress, diff --git a/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381/Internal.hs b/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381/Internal.hs index e7aacdbf0..92a3799b8 100644 --- a/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381/Internal.hs +++ b/cardano-crypto-class/src/Cardano/Crypto/EllipticCurve/BLS12_381/Internal.hs @@ -129,6 +129,7 @@ module Cardano.Crypto.EllipticCurve.BLS12_381.Internal ( blsMult, blsCneg, blsNeg, + blsMSM, blsCompress, blsSerialize, blsUncompress, @@ -175,6 +176,8 @@ import Foreign.Marshal.Utils (copyBytes) import Foreign.Ptr (Ptr, castPtr, nullPtr, plusPtr) import Foreign.Storable (peek) import System.IO.Unsafe (unsafePerformIO) +import qualified Data.List.NonEmpty as NonEmpty +import Control.Monad (zipWithM_) ---- Phantom Types @@ -189,10 +192,14 @@ type Point1Ptr = PointPtr Curve1 type Point2Ptr = PointPtr Curve2 newtype AffinePtr curve = AffinePtr (Ptr Void) +newtype AffinePtrVector curve = AffinePtrVector (Ptr Void) type Affine1Ptr = AffinePtr Curve1 type Affine2Ptr = AffinePtr Curve2 +type Affine1PtrList = AffinePtrVector Curve1 +type Affine2PtrList = AffinePtrVector Curve2 + newtype PTPtr = PTPtr (Ptr Void) unsafePointFromPointPtr :: PointPtr curve -> Point curve @@ -288,6 +295,19 @@ withNewAffine_ = fmap fst . withNewAffine withNewAffine' :: BLS curve => (AffinePtr curve -> IO a) -> IO (Affine curve) withNewAffine' = fmap snd . withNewAffine +-- Helper: Converts a list of affine points to a contiguous memory block +withAffineList :: forall curve a. BLS curve => [Affine curve] -> (AffinePtrVector curve -> IO a) -> IO a +withAffineList affines go = do + let numAffines = length affines + let sizeAffine' = sizeAffine (Proxy @curve) + allocaBytes (numAffines * sizeAffine') $ \ptr -> do + -- Copy each affine point to the memory block + let copyAffineAtIx ix affine = + withAffine affine $ \(AffinePtr aPtr) -> + copyBytes (ptr `plusPtr` (ix * sizeAffine')) (castPtr aPtr) sizeAffine' + zipWithM_ copyAffineAtIx [0..] affines + go (AffinePtrVector ptr) + withPT :: PT -> (PTPtr -> IO a) -> IO a withPT (PT pt) go = withForeignPtr pt (go . PTPtr) @@ -317,6 +337,9 @@ class BLS curve where c_blst_mult :: PointPtr curve -> PointPtr curve -> ScalarPtr -> CSize -> IO () c_blst_cneg :: PointPtr curve -> Bool -> IO () + c_blst_scratch_sizeof :: Proxy curve -> CSize -> CSize + c_blst_mult_pippenger :: PointPtr curve -> AffinePtrVector curve -> CSize -> ScalarPtrList -> CSize -> ScratchPtr -> IO () + c_blst_hash :: PointPtr curve -> Ptr CChar -> CSize -> Ptr CChar -> CSize -> Ptr CChar -> CSize -> IO () c_blst_compress :: Ptr CChar -> PointPtr curve -> IO () @@ -345,6 +368,9 @@ instance BLS Curve1 where c_blst_mult = c_blst_p1_mult c_blst_cneg = c_blst_p1_cneg + c_blst_scratch_sizeof _ = c_blst_p1s_mult_pippenger_scratch_sizeof + c_blst_mult_pippenger = c_blst_p1s_mult_pippenger + c_blst_hash = c_blst_hash_to_g1 c_blst_compress = c_blst_p1_compress c_blst_serialize = c_blst_p1_serialize @@ -373,6 +399,9 @@ instance BLS Curve2 where c_blst_mult = c_blst_p2_mult c_blst_cneg = c_blst_p2_cneg + c_blst_scratch_sizeof _ = c_blst_p2s_mult_pippenger_scratch_sizeof + c_blst_mult_pippenger = c_blst_p2s_mult_pippenger + c_blst_hash = c_blst_hash_to_g2 c_blst_compress = c_blst_p2_compress c_blst_serialize = c_blst_p2_serialize @@ -428,6 +457,18 @@ withNewScalar_ = fmap fst . withNewScalar withNewScalar' :: (ScalarPtr -> IO a) -> IO Scalar withNewScalar' = fmap snd . withNewScalar +-- Helper: Converts a list of scalars to a contiguous memory block +withScalarList :: [Scalar] -> (ScalarPtrList -> IO a) -> IO a +withScalarList scalars go = do + let numScalars = length scalars + allocaBytes (numScalars * sizeScalar) $ \ptr -> do + -- Copy each scalar to the memory block + let copyScalarAtIx ix scalar = + withScalar scalar $ \(ScalarPtr sPtr) -> + copyBytes (ptr `plusPtr` (ix * sizeScalar)) (castPtr sPtr) sizeScalar + zipWithM_ copyScalarAtIx [0..] scalars + go (ScalarPtrList ptr) + cloneScalar :: Scalar -> IO Scalar cloneScalar (Scalar a) = do b <- mallocForeignPtrBytes sizeScalar @@ -512,7 +553,9 @@ scalarFromInteger n = do ---- Unsafe types newtype ScalarPtr = ScalarPtr (Ptr Void) +newtype ScalarPtrList = ScalarPtrList (Ptr Void) newtype FrPtr = FrPtr (Ptr Void) +newtype ScratchPtr = ScratchPtr (Ptr Void) ---- Raw Scalar / Fr functions @@ -555,6 +598,9 @@ foreign import ccall "blst_p1_generator" c_blst_p1_generator :: Point1Ptr foreign import ccall "blst_p1_is_equal" c_blst_p1_is_equal :: Point1Ptr -> Point1Ptr -> IO Bool foreign import ccall "blst_p1_is_inf" c_blst_p1_is_inf :: Point1Ptr -> IO Bool +foreign import ccall "blst_p1s_mult_pippenger_scratch_sizeof" c_blst_p1s_mult_pippenger_scratch_sizeof :: CSize -> CSize +foreign import ccall "blst_p1s_mult_pippenger" c_blst_p1s_mult_pippenger :: Point1Ptr -> Affine1PtrList -> CSize -> ScalarPtrList -> CSize -> ScratchPtr -> IO () + ---- Raw Point2 functions foreign import ccall "size_blst_p2" c_size_blst_p2 :: CSize @@ -582,6 +628,9 @@ foreign import ccall "blst_p2_generator" c_blst_p2_generator :: Point2Ptr foreign import ccall "blst_p2_is_equal" c_blst_p2_is_equal :: Point2Ptr -> Point2Ptr -> IO Bool foreign import ccall "blst_p2_is_inf" c_blst_p2_is_inf :: Point2Ptr -> IO Bool +foreign import ccall "blst_p2s_mult_pippenger_scratch_sizeof" c_blst_p2s_mult_pippenger_scratch_sizeof :: CSize -> CSize +foreign import ccall "blst_p2s_mult_pippenger" c_blst_p2s_mult_pippenger :: Point2Ptr -> Affine2PtrList -> CSize -> ScalarPtrList -> CSize -> ScratchPtr -> IO () + ---- Affine operations foreign import ccall "size_blst_affine1" c_size_blst_affine1 :: CSize @@ -824,7 +873,8 @@ blsZero = error $ "Unexpected failure deserialising point at infinity on BLS12_381.G1: " ++ show err Right infinity -> infinity -- The zero point on this curve is chosen to be the point at infinity. - ---- Scalar / Fr operations + +---- Scalar / Fr operations scalarFromFr :: Fr -> IO Scalar scalarFromFr fr = @@ -875,6 +925,46 @@ scalarCanonical scalar = unsafePerformIO $ withScalar scalar c_blst_scalar_fr_check +---- MSM operations + +-- | Multi-scalar multiplication using the Pippenger algorithm. +-- The number of points must be equal or smaller than the number of scalars. +-- For reference, see the usage of the rust bindings: https://github.com/perturbing/blst/blob/master/bindings/rust/src/pippenger.rs#L143C1-L161C11 +-- Note that we only implement the single thread version of the algorithm. +blsMSM :: forall curve. BLS curve => NonEmpty.NonEmpty (Point curve, Scalar) -> Point curve +blsMSM psAndSs = + unsafePerformIO $ do + -- Split points and scalars into separate lists + let (affinePoints, scalars) = unzip $ NonEmpty.toList psAndSs + + -- Convert points to affine representations + let affinePoints' = map toAffine affinePoints + + -- Allocate memory for affine points and scalars + withAffineList affinePoints' $ \affineListPtr -> + withScalarList scalars $ \scalarListPtr -> do + + -- Calculate required scratch size + let numPoints :: CSize + numPoints = fromIntegral @Int @CSize $ NonEmpty.length psAndSs + scratchSize :: Int + scratchSize = fromIntegral @CSize @Int $ c_blst_scratch_sizeof (Proxy @curve) numPoints + affineSize = sizeAffine (Proxy @curve) + + -- Allocate scratch space + allocaBytes (scratchSize * affineSize) $ \scratchPtr -> do + + -- Allocate memory for the result point + withNewPoint' $ \resultPtr -> do + -- Perform the MSM + c_blst_mult_pippenger + resultPtr + affineListPtr + numPoints + scalarListPtr + 255 + (ScratchPtr scratchPtr) + ---- PT operations ptMult :: PT -> PT -> PT