diff --git a/src/Streamly/Internal/Data/Stream/IsStream/Top.hs b/src/Streamly/Internal/Data/Stream/IsStream/Top.hs index b2ee66fe5a..15f3e5ced1 100644 --- a/src/Streamly/Internal/Data/Stream/IsStream/Top.hs +++ b/src/Streamly/Internal/Data/Stream/IsStream/Top.hs @@ -43,7 +43,7 @@ module Streamly.Internal.Data.Stream.IsStream.Top , joinLeftMerge , hashLeftJoin , outerJoin - , mergeOuterJoin + , joinOuterMerge , hashOuterJoin ) where @@ -473,11 +473,14 @@ hashOuterJoin _eq _s1 _s2 = undefined -- -- Time: O(m + n) -- --- /Unimplemented/ -{-# INLINE mergeOuterJoin #-} -mergeOuterJoin :: -- Monad m => +-- /Pre-release/ +{-# INLINE joinOuterMerge #-} +joinOuterMerge :: (IsStream t, MonadIO m, Eq a, Eq b) => (a -> b -> Ordering) -> t m a -> t m b -> t m (Maybe a, Maybe b) -mergeOuterJoin _eq _s1 _s2 = undefined +joinOuterMerge eq s1 = + IsStream.fromStreamD + . StreamD.joinOuterMerge eq (IsStream.toStreamD s1) + . IsStream.toStreamD ------------------------------------------------------------------------------ -- Set operations (special joins) diff --git a/src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs b/src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs index 6e88e91f2a..b5dcf92817 100644 --- a/src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs +++ b/src/Streamly/Internal/Data/Stream/StreamD/Nesting.hs @@ -1,3 +1,4 @@ +{-#LANGUAGE GADTs #-} -- | -- Module : Streamly.Internal.Data.Stream.StreamD.Nesting -- Copyright : (c) 2018 Composewell Technologies @@ -147,6 +148,7 @@ module Streamly.Internal.Data.Stream.StreamD.Nesting , differenceBySorted , joinInnerMerge , joinLeftMerge + , joinOuterMerge ) where @@ -157,6 +159,7 @@ import Control.Monad.Catch (MonadThrow, throwM) import Control.Monad.IO.Class (MonadIO(..)) import Data.Bits (shiftR, shiftL, (.|.), (.&.)) import Data.IORef +import Data.Maybe (isJust) #if __GLASGOW_HASKELL__ >= 801 import Data.Functor.Identity ( Identity ) #endif @@ -4011,4 +4014,700 @@ joinLeftMerge cmp (Stream stepa ta) (Stream stepb tb) = step _ (_, _, _, _, _, _, _, _, _) = do -- liftIO $ print "Step 11" - return Stop \ No newline at end of file + return Stop + +------------------------------------------------------------------------------- +-- Outer Join sorted streams -------------------------------------------------- +------------------------------------------------------------------------------- + +data OuterJoinStep = + OJS_INIT -- initial step to initialize the duplicate items list + | OJS_PULL_RIGHT -- pull the data from right stream + | OJS_PULL_RIGHT_DUP -- pull the data from right stram for duplicate match + | OJS_PULL_LEFT_INIT -- pull the data from left stream very first time + | OJS_PULL_LEFT_RUNNING -- pull the data from left stream after first time + | OJS_COMPARE -- compare the data of left stream with right one + | OJS_COMPARE_DUP -- compare data from the right stream to find the duplicates + | OJS_BUFF_PREP -- cache the duplicate data in right stream as list + | OJS_BUFF_PAIR -- join the left stream data with cached data (cartesian) + | OJS_BUFF_RESET -- empty the cache + | OJS_LEFT_REMAINS -- right stream is finished so drain out the left stream + | OJS_RIGHT_REMAINS -- left stream is finished so drain out the right stream + +-- +-- >>> toList $ Top.joinOuterMerge compare (fromList [1,2,2]) (fromList [2,2,3,4]) +-- [(Just 1,Nothing),(Just 2,Just 2),(Just 2,Just 2),(Just 2,Just 2),(Just 2,Just 2),(Nothing,Just 3),(Nothing,Just 4)] +-- +-- In case of duplicate elements in right stream we need to cache it so the duplicate matching +-- element of left stream will pair up with the same list. +-- +{-# INLINE_NORMAL joinOuterMerge #-} +joinOuterMerge :: (MonadIO m, Eq a, Eq b) => + (a -> b -> Ordering) + -> Stream m a + -> Stream m b + -> Stream m (Maybe a, Maybe b) + +joinOuterMerge cmp (Stream stepa ta) (Stream stepb tb) = + Stream + step + ( OJS_INIT -- OuterJoinStep + , Just ta -- state of left stream + , Just tb -- state of right stream + , Nothing -- current value of the left stream + , Nothing -- current value of the right stream + , Nothing -- previous value of the left stream + , Nothing -- previous value of the right stream + , Nothing -- IORef to list of repeated elements in right stream + , 0::Int -- current index of list of repeated elements + ) + + where + {-# INLINE_LATE step #-} + + -- step 0 + -- Initialize the duplicate elements in right stream as a reference to list + step + _gst + ( OJS_INIT + , Just sa + , Just sb + , Nothing + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) = + do + ref <- liftIO $ newIORef [] + return $ + Skip + ( OJS_PULL_LEFT_INIT + , Just sa + , Just sb + , Nothing + , Nothing + , Nothing + , Nothing + , Just ref + , idx + ) + + -- step 1 + -- pull the data from left stream very first time + step + gst + ( OJS_PULL_LEFT_INIT + , Just sa + , Just sb + , Nothing + , Nothing + , pa + , pb + , buff + , idx + ) = + do + r <- stepa (adaptState gst) sa + return $ case r of + Yield a' sa' -> + Skip + ( OJS_PULL_RIGHT + , Just sa' + , Just sb + , Just a' + , Nothing + , pa + , pb + , buff + , idx + ) + Skip sa' -> + Skip + ( OJS_PULL_LEFT_INIT + , Just sa' + , Just sb + , Nothing + , Nothing + , pa + , pb + , buff + , idx + ) + Stop -> + Skip + ( OJS_RIGHT_REMAINS + , Nothing + , Just sb + , Nothing + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + + -- step 2 + -- pull the data from left stream apart form first time + step + gst + ( OJS_PULL_LEFT_RUNNING + , Just sa + , sb + , Just a + , Just b + , pa + , pb + , buff + , idx + ) = + do + r <- stepa (adaptState gst) sa + return $ case r of + Yield a' sa' -> + Skip + ( OJS_COMPARE + , Just sa' + , sb + , Just a' + , Just b + , pa + , pb + , buff + , idx + ) + Skip sa' -> + Skip + ( OJS_PULL_LEFT_RUNNING + , Just sa' + , sb + , Just a + , Just b + , pa + , pb + , buff + , idx + ) + Stop -> + Yield + (Nothing, Just b) + ( OJS_RIGHT_REMAINS + , Nothing + , sb + , Nothing + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + + -- step 3 + -- pull the data from right stream and compare with data from left stream + step gst (OJS_PULL_RIGHT, Just sa, Just sb, Just a, b, pa, pb, buff, idx) = + do + r <- stepb (adaptState gst) sb + return $ case r of + Yield b' sb' -> + Skip + ( OJS_COMPARE + , Just sa + , Just sb' + , Just a + , Just b' + , pa + , Just b' + , buff + , idx + ) + Skip sb' -> + Skip + ( OJS_PULL_RIGHT + , Just sa + , Just sb' + , Just a + , b + , pa + , pb + , buff + , idx + ) + Stop -> + Yield + (Just a, Nothing) + ( OJS_LEFT_REMAINS + , Just sa + , Nothing + , Just a + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + + -- step 4 + -- compare the data from left with right stream + step + _gst + (OJS_COMPARE, Just sa, Just sb, Just a, Just b, pa, pb, buff, idx) = + do + let res = cmp a b + return $ case res of + -- a < b , since b is increasing order so a will never match with b + LT -> + Yield + (Just a, Nothing) + ( OJS_PULL_LEFT_RUNNING + , Just sa + , Just sb + , Just a + , Just b + , pa + , pb + , buff + , idx + ) + + -- cache the duplicate items in right stream + EQ -> + Skip + ( OJS_BUFF_PREP + , Just sa + , Just sb + , Just a + , Just b + , Just a + , Just b + , buff + , idx + ) + + -- a > b , since b is increasing order so look for bigger data + GT -> + Yield + (Nothing, Just b) + ( OJS_PULL_RIGHT_DUP + , Just sa + , Just sb + , Just a + , Just b + , pa + , Just b + , buff + , idx + ) + + -- step 5 + -- compare b with previous b to find duplicates in the right stream + step + _gst + ( OJS_COMPARE_DUP + , Just sa + , sb + , Just a + , Just b + , pa + , Just pb + , buff + , idx + ) = + return $ + if b == pb + then + Yield + (Nothing, Just b) + ( OJS_PULL_RIGHT_DUP + , Just sa + , sb + , Just a + , Just b + , pa + , Just pb + , buff + , idx + ) + else + Skip + ( OJS_COMPARE + , Just sa + , sb + , Just a + , Just b + , pa + , Just b + , buff + , idx + ) + + -- step 6 + -- pull the data from right stream to get compared for duplicates + step + gst + ( OJS_PULL_RIGHT_DUP + , Just sa + , Just sb + , Just a + , b + , pa + , pb + , buff + , idx + ) = + do + r <- stepb (adaptState gst) sb + return $ + case r of + Yield b' sb' -> + Skip + ( OJS_COMPARE_DUP + , Just sa + , Just sb' + , Just a + , Just b' + , Just a + , pb + , buff + , idx + ) + + Skip sb' -> + Skip + ( OJS_PULL_RIGHT_DUP + , Just sa + , Just sb' + , Just a + , b + , pa + , pb + , buff + , idx + ) + Stop -> + Yield + (Just a, Nothing) + ( OJS_LEFT_REMAINS + , Just sa + , Nothing + , Just a + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + + -- step 7 + -- cache the duplicate elements in a list + step + gst + ( OJS_BUFF_PREP + , Just sa + , Just sb + , Just a + , Just b + , pa + , Just pb + , Just buff + , idx + ) = + do + liftIO $ modifyIORef' buff (b : ) + r <- stepb (adaptState gst) sb + case r of + Yield b' sb' -> do + if b' == pb + then do + return $ + Skip + ( OJS_BUFF_PREP + , Just sa + , Just sb' + , Just a + , Just b' + , pa + , Just b + , Just buff + , idx + ) + else + return $ + Skip + ( OJS_BUFF_PAIR + , Just sa + , Just sb' + , Just a + , Just b' + , pa + , Just b' + , Just buff + , 0 + ) + Skip sb' -> + return $ + Skip + ( OJS_PULL_RIGHT + , Just sa + , Just sb' + , Just a + , Just b + , pa + , Just pb + , Just buff + , idx + ) + Stop -> + return $ + Skip + ( OJS_BUFF_PAIR + , Just sa + , Just sb + , Just a + , Nothing + , pa + , Just pb + , Just buff + , 0 + ) + + -- step 8 + -- pairing with buffed elements (only when repeatation is over) + step + _gst + ( OJS_BUFF_PAIR + , Just sa + , Just sb + , Just a + , b + , pa + , Just pb + , Just buff + , idx + ) = + do + bl <- liftIO $ readIORef buff + if idx < length bl + then + return $ + Yield + (Just a, Just (bl !! idx)) + ( OJS_BUFF_PAIR + , Just sa + , Just sb + , Just a + , b + , pa + , Just pb + , Just buff + , idx+1 + ) + else + return $ + Skip + ( OJS_BUFF_RESET + , Just sa + , Just sb + , Just a + , b + , Just a + , Just pb + , Just buff + , 0 + ) + + -- step 9 + -- empty the cache of duplicate elements if there is no more duplicates in left stream + step + gst + ( OJS_BUFF_RESET + , Just sa + , Just sb + , Just a + , b + , Just pa + , pb + , Just buff + , idx + ) = + do + r <- stepa (adaptState gst) sa + case r of + Yield a' sa' -> + do + if a' == pa -- left stream has duplicate data so pair up with cached matching data + then + return $ + Skip + ( OJS_BUFF_PAIR + , Just sa' + , Just sb + , Just a' + , b + , Just a' + , pb + , Just buff + , idx + ) + else do + -- clear the cache + liftIO $ writeIORef buff [] + return $ + if isJust b + then + Skip + ( OJS_COMPARE + , Just sa' + , Just sb + , Just a' + , b + , Just a' + , pb + , Just buff + , idx + ) + else -- right stream is finished + Yield + (Just a', Nothing) + ( OJS_LEFT_REMAINS + , Just sa' + , Nothing + , Just a' + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + Skip sa' -> + return $ + Skip + ( OJS_BUFF_RESET + , Just sa' + , Just sb + , Just a + , b + , Just pa + , pb + , Just buff + , idx + ) + Stop -> + return $ + if isJust b + then -- already has read the data emit and proceed + Yield + (Nothing, b) + ( OJS_RIGHT_REMAINS + , Nothing + , Just sb + , Nothing + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + else + Skip -- read and emit till end + ( OJS_RIGHT_REMAINS + , Nothing + , Just sb + , Nothing + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + + -- step 10 + -- right stream is finished, drain out left stream. + step + gst + ( OJS_LEFT_REMAINS + , Just sa + , Nothing + , Just a + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) = + do + r <- stepa (adaptState gst) sa + return $ case r of + Yield a' sa' -> + Yield + (Just a', Nothing) + ( OJS_LEFT_REMAINS + , Just sa' + , Nothing + , Just a' + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + Skip sa' -> + Skip + ( OJS_LEFT_REMAINS + , Just sa' + , Nothing + , Just a + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + Stop -> Stop + + -- step 11 + -- left stream is finished, drain out right stream. + step + gst + ( OJS_RIGHT_REMAINS + , Nothing + , Just sb + , Nothing + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) = + do -- read and emit till end + r <- stepb (adaptState gst) sb + return $ case r of + Yield b' sb' -> + Yield + (Nothing, Just b') + ( OJS_RIGHT_REMAINS + , Nothing + , Just sb' + , Nothing + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + Skip sb' -> + Skip + ( OJS_RIGHT_REMAINS + , Nothing + , Just sb' + , Nothing + , Nothing + , Nothing + , Nothing + , Nothing + , idx + ) + Stop -> Stop + + -- step 12 + step _ (_, _, _, _, _, _, _, _, _) = return Stop diff --git a/test/Streamly/Test/Data/Stream/Top.hs b/test/Streamly/Test/Data/Stream/Top.hs index 2c0bd91d70..326dd39466 100644 --- a/test/Streamly/Test/Data/Stream/Top.hs +++ b/test/Streamly/Test/Data/Stream/Top.hs @@ -2,6 +2,7 @@ module Main (main) where import Data.List (elem, intersect, nub, sort, union, (\\)) +import Data.Maybe (isNothing) import Test.QuickCheck ( Gen , Property @@ -126,6 +127,37 @@ joinLeftMerge = then return (i, Just i) else return (i, Nothing) assert (v1 == v2) + +joinOuterMerge :: Property +joinOuterMerge = + forAll (listOf (chooseInt (min_value, max_value))) $ \ls0 -> + forAll (listOf (chooseInt (min_value, max_value))) $ \ls1 -> + monadicIO $ action (sort ls0) (sort (nub ls1)) + + where + + action ls0 ls1 = do + v1 <- + run + $ S.toList + $ Top.joinOuterMerge + compare + (S.fromList ls0) + (S.fromList ls1) + let v2 = do + i <- ls0 + if (elem i ls1) + then return (Just i, Just i) + else return (Just i, Nothing) + v3 = do + j <- ls1 + if (elem j ls0) + then return (Just j, Just j) + else return (Nothing, Just j) + v4 = filter (\(a1, _) -> isNothing a1) v3 + + assert (sort v1 == sort (v2 ++ v4)) + ------------------------------------------------------------------------------- moduleName :: String moduleName = "Data.Stream.Top" @@ -139,3 +171,4 @@ main = hspec $ do prop "differenceBySorted" Main.differenceBySorted prop "joinInnerMerge" Main.joinInnerMerge prop "joinLeftMerge" Main.joinLeftMerge + prop "joinOuterMerge" Main.joinOuterMerge