Skip to content

Commit

Permalink
use binary instead of cereal.
Browse files Browse the repository at this point in the history
  • Loading branch information
bohde committed May 28, 2015
1 parent f5a4086 commit 5b9daab
Show file tree
Hide file tree
Showing 8 changed files with 52 additions and 53 deletions.
2 changes: 1 addition & 1 deletion plugin/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import Data.Int
import Data.Monoid
import Data.ProtocolBuffers
import Data.ProtocolBuffers.Internal
import Data.Serialize
import Data.Binary
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Lazy as TextL
Expand Down
4 changes: 2 additions & 2 deletions protobuf.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ library
build-depends:
base >= 4.7 && < 5,
bytestring >= 0.9,
cereal >= 0.3,
binary >= 0.7,
data-binary-ieee754 >= 0.4,
deepseq >= 1.1,
mtl == 2.*,
Expand Down Expand Up @@ -89,7 +89,7 @@ test-suite protobuf-test
build-depends:
base >= 4.7 && < 5,
bytestring,
cereal,
binary,
containers,
hex,
mtl,
Expand Down
10 changes: 5 additions & 5 deletions src/Data/ProtocolBuffers.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
--import "GHC.Generics" ('GHC.Generics.Generic')
--import "GHC.TypeLits"
--import "Data.Monoid"
--import "Data.Serialize"
--import "Data.Binary"
--import "Data.Hex" -- cabal install hex (for testing)
--
-- data Foo = Foo
Expand All @@ -41,15 +41,15 @@
--
-- >>> let msg = Foo{field1 = putField 42, field2 = mempty, field3 = putField [True, False]}
--
-- To serialize a message first convert it into a 'Data.Serialize.Put' by way of 'encodeMessage'
-- and then to a 'Data.ByteString.ByteString' by using 'Data.Serialize.runPut'. Lazy
-- 'Data.ByteString.Lazy.ByteString' serialization is done with 'Data.Serialize.runPutLazy'.
-- To serialize a message first convert it into a 'Data.Binary.Put' by way of 'encodeMessage'
-- and then to a 'Data.ByteString.ByteString' by using 'Data.Binary.runPut'. Lazy
-- 'Data.ByteString.Lazy.ByteString' serialization is done with 'Data.Binary.runPutLazy'.
--
-- >>> fmap hex runPut $ encodeMessage msg
-- "082A18011800"
--
-- Decoding is done with the inverse functions: 'decodeMessage'
-- and 'Data.Serialize.runGet', or 'Data.Serialize.runGetLazy'.
-- and 'Data.Binary.runGet'.
--
-- >>> runGet decodeMessage =<< unhex "082A18011800" :: Either String Foo
-- Right
Expand Down
16 changes: 8 additions & 8 deletions src/Data/ProtocolBuffers/Decode.hs
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,22 @@ module Data.ProtocolBuffers.Decode

import Control.Applicative
import Control.Monad
import qualified Data.ByteString as B
import Data.Foldable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import Data.Int (Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.Monoid
import Data.Proxy
import Data.Serialize.Get
import Data.Binary.Get
import Data.Traversable (traverse)

import GHC.Generics
import GHC.TypeLits

import Data.ProtocolBuffers.Types
import Data.ProtocolBuffers.Wire
import qualified Data.ByteString.Lazy as LBS

-- |
-- Decode a Protocol Buffers message.
Expand All @@ -51,12 +51,12 @@ decodeLengthPrefixedMessage :: Decode a => Get a
{-# INLINE decodeLengthPrefixedMessage #-}
decodeLengthPrefixedMessage = do
len :: Int64 <- getVarInt
bs <- getBytes $ fromIntegral len
case runGetState decodeMessage bs 0 of
Right (val, bs')
| B.null bs' -> return val
| otherwise -> fail $ "Unparsed bytes leftover in decodeLengthPrefixedMessage: " ++ show (B.length bs')
Left err -> fail err
bs <- getByteString $ fromIntegral len
case runGetOrFail decodeMessage (LBS.fromStrict bs) of
Right (bs', _, val)
| LBS.null bs' -> return val
| otherwise -> fail $ "Unparsed bytes leftover in decodeLengthPrefixedMessage: " ++ show (LBS.length bs')
Left (_, _, err) -> fail err

class Decode (a :: *) where
decode :: HashMap Tag [WireField] -> Get a
Expand Down
8 changes: 4 additions & 4 deletions src/Data/ProtocolBuffers/Encode.hs
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ module Data.ProtocolBuffers.Encode
, GEncode
) where

import qualified Data.ByteString as B
import Data.Foldable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import Data.Proxy
import Data.Serialize.Put
import Data.Binary.Put

import GHC.Generics
import GHC.TypeLits

import Data.ProtocolBuffers.Types
import Data.ProtocolBuffers.Wire
import qualified Data.ByteString.Lazy as LBS

-- |
-- Encode a Protocol Buffers message.
Expand All @@ -36,8 +36,8 @@ encodeLengthPrefixedMessage :: Encode a => a -> Put
{-# INLINE encodeLengthPrefixedMessage #-}
encodeLengthPrefixedMessage msg = do
let msg' = runPut $ encodeMessage msg
putVarUInt $ B.length msg'
putByteString msg'
putVarUInt $ LBS.length msg'
putLazyByteString msg'

class Encode (a :: *) where
encode :: a -> Put
Expand Down
10 changes: 4 additions & 6 deletions src/Data/ProtocolBuffers/Message.hs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ import Control.Applicative
import Control.DeepSeq (NFData(..))
import Data.Foldable
import Data.Monoid
import Data.Serialize.Get
import Data.Serialize.Put
import Data.Binary.Get
import Data.Binary.Put
import Data.Traversable

import GHC.Generics
Expand All @@ -29,6 +29,7 @@ import Data.ProtocolBuffers.Decode
import Data.ProtocolBuffers.Encode
import Data.ProtocolBuffers.Types
import Data.ProtocolBuffers.Wire
import qualified Data.ByteString.Lazy as LBS

-- |
-- The way to embed a message within another message.
Expand Down Expand Up @@ -160,10 +161,7 @@ instance (Foldable f, Encode m) => EncodeWire (f (Message m)) where
traverse_ (encodeWire t . runPut . encode . runMessage)

instance Decode m => DecodeWire (Message m) where
decodeWire (DelimitedField _ bs) =
case runGet decodeMessage bs of
Right val -> pure $ Message val
Left err -> fail $ "Embedded message decoding failed: " ++ show err
decodeWire (DelimitedField _ bs) = pure $ Message $ runGet decodeMessage $ LBS.fromStrict bs
decodeWire _ = empty

-- | Iso: @ 'FieldType' ('Required' n ('Message' a)) = a @
Expand Down
20 changes: 10 additions & 10 deletions src/Data/ProtocolBuffers/Wire.hs
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,17 @@ import Control.Applicative
import Data.Bits
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as LBS
import Data.Foldable
import Data.Int
import Data.Monoid
import Data.Serialize.Get
import Data.Serialize.IEEE754
import Data.Serialize.Put
import Data.Binary.Get
import Data.Binary.IEEE754
import Data.Binary.Put
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Data.Typeable
import Data.Word
import Data.Binary.IEEE754 (wordToDouble, wordToFloat)

import Data.ProtocolBuffers.Types

Expand All @@ -60,7 +60,7 @@ data WireField
deriving (Eq, Ord, Show, Typeable)

getVarintPrefixedBS :: Get ByteString
getVarintPrefixedBS = getBytes =<< getVarInt
getVarintPrefixedBS = getByteString =<< getVarInt

putVarintPrefixedBS :: ByteString -> Put
putVarintPrefixedBS bs = putVarUInt (B.length bs) >> putByteString bs
Expand Down Expand Up @@ -259,6 +259,9 @@ instance DecodeWire Double where
decodeWire (Fixed64Field _ val) = pure $ wordToDouble val
decodeWire _ = empty

instance EncodeWire LBS.ByteString where
encodeWire t val = putWireTag t 2 >> putVarUInt (LBS.length val) >> putLazyByteString val

instance EncodeWire ByteString where
encodeWire t val = putWireTag t 2 >> putVarUInt (B.length val) >> putByteString val

Expand All @@ -284,10 +287,7 @@ instance DecodeWire T.Text where

decodePackedList :: Get a -> WireField -> Get [a]
{-# INLINE decodePackedList #-}
decodePackedList g (DelimitedField _ bs) =
case runGet (many g) bs of
Right val -> return val
Left err -> fail err
decodePackedList g (DelimitedField _ bs) = return $ runGet (many g) (LBS.fromStrict bs)
decodePackedList _ _ = empty

-- |
Expand All @@ -296,7 +296,7 @@ encodePackedList :: Tag -> Put -> Put
{-# INLINE encodePackedList #-}
encodePackedList t p
| bs <- runPut p
, not (B.null bs) = encodeWire t bs
, not (LBS.null bs) = encodeWire t (LBS.toStrict bs)
| otherwise = pure ()

instance EncodeWire (PackedList (Value Int32)) where
Expand Down
35 changes: 18 additions & 17 deletions tests/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,14 @@ import Data.Int
import Data.IntSet (IntSet)
import qualified Data.IntSet as IntSet
import Data.Monoid
import Data.Serialize (Get, Putter, runGet, runPut)
import Data.Binary.Get (Get, runGet)
import Data.Binary.Put (Put, runPut)
import Data.Proxy
import Data.Text (Text)
import Data.Typeable
import Data.Word
import qualified Data.ByteString.Lazy as LBS
import Data.Binary.Get (runGetOrFail)

main :: IO ()
main = defaultMain tests
Expand Down Expand Up @@ -301,9 +304,9 @@ prop_wire _ = label ("prop_wire :: " ++ show (typeOf (undefined :: a))) $ do
field <- getWireField
guard $ tag == wireFieldTag field
decodeWire field
case runGet dec bs of
Right val' -> return $ val == val'
Left err -> fail err
case runGetOrFail dec bs of
Right (_, _, val') -> return $ val == val'
Left (_, _, err) -> fail err

prop_generic :: Gen Property
prop_generic = do
Expand All @@ -314,33 +317,31 @@ prop_generic_length_prefixed :: Gen Property
prop_generic_length_prefixed = do
msg <- HashMap.fromListWith (++) . fmap (\ c -> (wireFieldTag c, [c])) <$> listOf1 arbitrary
let bs = runPut $ encodeLengthPrefixedMessage (msg :: HashMap Tag [WireField])
case runGet decodeLengthPrefixedMessage bs of
Right msg' -> return $ counterexample "foo" $ msg == msg'
Left err -> fail err
case runGetOrFail decodeLengthPrefixedMessage bs of
Right (_, _, msg') -> return $ counterexample "foo" $ msg == msg'
Left (_, _, err) -> fail err

prop_roundtrip_msg :: (Eq a, Encode a, Decode a) => a -> Gen Property
prop_roundtrip_msg msg = do
let bs = runPut $ encodeMessage msg
case runGet decodeMessage bs of
Right msg' -> return . property $ msg == msg'
Left err -> fail err
msg' -> return . property $ msg == msg'

prop_varint_prefixed_bytestring :: Gen Property
prop_varint_prefixed_bytestring = do
bs <- B.pack <$> arbitrary
prop_roundtrip_value getVarintPrefixedBS putVarintPrefixedBS bs

prop_roundtrip_value :: (Eq a, Show a) => Get a -> Putter a -> a -> Gen Property
prop_roundtrip_value :: (Eq a, Show a) => Get a -> (a -> Put) -> a -> Gen Property
prop_roundtrip_value get put val = do
let bs = runPut (put val)
case runGet get bs of
Right val' -> return $ val === val'
Left err -> fail err
val' -> return $ val === val'

prop_encode_fail :: Encode a => a -> Gen Prop
prop_encode_fail msg = unProperty $ ioProperty $ do
res <- try . evaluate . runPut $ encodeMessage msg
return $ case res :: Either SomeException B.ByteString of
return $ case res :: Either SomeException LBS.ByteString of
Left _ -> True
Right _ -> False

Expand Down Expand Up @@ -409,11 +410,11 @@ prop_opt _ = label ("prop_opt :: " ++ show (typeOf (undefined :: a))) $ do
testSpecific :: (Eq a, Show a, Encode a, Decode a) => a -> B.ByteString -> IO ()
testSpecific msg ref = do
let bs = runPut $ encodeMessage msg
assertEqual "Encoded message mismatch" bs ref
assertEqual "Encoded message mismatch" (LBS.toStrict bs) ref

case runGet decodeMessage bs of
Right msg' -> assertEqual "Decoded message mismatch" msg msg'
Left err -> assertFailure err
case runGetOrFail decodeMessage bs of
Right (_, _, msg') -> assertEqual "Decoded message mismatch" msg msg'
Left (_, _, err) -> assertFailure err

data Test1 = Test1{test1_a :: Required 1 (Value Int32)} deriving (Generic)
deriving instance Eq Test1
Expand Down

0 comments on commit 5b9daab

Please sign in to comment.