diff --git a/src/Language/JVM/Parser.hs b/src/Language/JVM/Parser.hs index 91533fb..135a54a 100644 --- a/src/Language/JVM/Parser.hs +++ b/src/Language/JVM/Parser.hs @@ -9,6 +9,8 @@ Portability : portable Parser for the JVM bytecode format. -} +{-# LANGUAGE LambdaCase #-} + module Language.JVM.Parser ( -- * Class declarations Class @@ -132,6 +134,22 @@ import System.IO import Language.JVM.CFG import Language.JVM.Common +-- | Indicate parse failure with the given error message. Note that +-- failure in the 'Get' monad already tracks the number of bytes +-- consumed, so it is not necessary to include position information in +-- the error message. +failure :: String -> Get a +failure msg = fail msg + +-- | Run an inner parser on a bytestring, passing along any failures. +subParser :: Get a -> L.ByteString -> Get a +subParser g l = + case runGetOrFail g l of + Left (_, pos, msg) -> + failure ("Sub-parser failed at position " ++ show pos ++ ": " ++ msg) + Right (_, _, x) -> + pure x + -- Version of replicate with arguments convoluted for parser. replicateN :: (Integral b, Monad m) => m a -> b -> m [a] replicateN fn i = sequence (replicate (fromIntegral i) fn) @@ -144,22 +162,26 @@ showOnNewLines n (a : rest) = replicate n ' ' ++ a ++ "\n" ++ showOnNewLines n r ---------------------------------------------------------------------- -- Type -parseTypeDescriptor :: String -> (Type, String) -parseTypeDescriptor ('B' : rest) = (ByteType, rest) -parseTypeDescriptor ('C' : rest) = (CharType, rest) -parseTypeDescriptor ('D' : rest) = (DoubleType, rest) -parseTypeDescriptor ('F' : rest) = (FloatType, rest) -parseTypeDescriptor ('I' : rest) = (IntType, rest) -parseTypeDescriptor ('J' : rest) = (LongType, rest) -parseTypeDescriptor ('L' : rest) = split rest [] - where split (';' : rest') result = (ClassType (mkClassName (reverse result)), rest') - split (ch : rest') result = split rest' (ch : result) - split _ _ = error "internal: unable to parse type descriptor" -parseTypeDescriptor ('S' : rest) = (ShortType, rest) -parseTypeDescriptor ('Z' : rest) = (BooleanType, rest) -parseTypeDescriptor ('[' : rest) = (ArrayType tp, result) - where (tp, result) = parseTypeDescriptor rest -parseTypeDescriptor st = error ("Unexpected type descriptor string " ++ st) +parseTypeDescriptor :: String -> Maybe (Type, String) +parseTypeDescriptor str = + case str of + ('B' : rest) -> Just (ByteType, rest) + ('C' : rest) -> Just (CharType, rest) + ('D' : rest) -> Just (DoubleType, rest) + ('F' : rest) -> Just (FloatType, rest) + ('I' : rest) -> Just (IntType, rest) + ('J' : rest) -> Just (LongType, rest) + ('S' : rest) -> Just (ShortType, rest) + ('Z' : rest) -> Just (BooleanType, rest) + ('L' : rest) -> split rest [] + ('[' : rest) -> + do (tp, result) <- parseTypeDescriptor rest + Just (ArrayType tp, result) + _ -> Nothing + where + split (';' : rest') result = Just (ClassType (mkClassName (reverse result)), rest') + split (ch : rest') result = split rest' (ch : result) + split _ _ = Nothing ---------------------------------------------------------------------- -- Visibility @@ -177,13 +199,20 @@ instance Show Visibility where ---------------------------------------------------------------------- -- Method descriptors -parseMethodDescriptor :: String -> (Maybe Type, [Type]) -parseMethodDescriptor ('(' : rest) = impl rest [] - where impl ")V" types = (Nothing, reverse types) - impl (')' : rest') types = (Just $ fst $ parseTypeDescriptor rest', reverse types) - impl text types = let (tp, rest') = parseTypeDescriptor text - in impl rest' (tp : types) -parseMethodDescriptor _ = error "internal: unable to parse method descriptor" +parseMethodDescriptor :: String -> Maybe (Maybe Type, [Type]) +parseMethodDescriptor str = + case str of + ('(' : rest) -> impl rest [] + _ -> Nothing + where + impl ")V" types = Just (Nothing, reverse types) + impl (')' : rest') types = + case parseTypeDescriptor rest' of + Just (tp, "") -> Just (Just tp, reverse types) + _ -> Nothing + impl text types = + do (tp, rest') <- parseTypeDescriptor text + impl rest' (tp : types) unparseMethodDescriptor :: MethodKey -> String unparseMethodDescriptor (MethodKey _ paramTys retTy) = @@ -205,7 +234,7 @@ makeMethodKey :: String -- ^ Method name -> String -- ^ Method descriptor -> MethodKey makeMethodKey name descriptor = MethodKey name parameters returnType - where (returnType, parameters) = parseMethodDescriptor descriptor + where Just (returnType, parameters) = parseMethodDescriptor descriptor mainKey :: MethodKey mainKey = makeMethodKey "main" "([Ljava/lang/String;)V" @@ -232,22 +261,35 @@ data ConstantPoolInfo | Phantom deriving (Show) --- Parses array of bytes from Java string -getJavaString :: [Word8] -> String -getJavaString [] = [] -getJavaString (x : rest) - | (x .&. 0x80) == 0 = chr (fromIntegral x) : getJavaString rest -getJavaString (x : y : rest) +-- | Parse a string from a list of bytes, according to section 4.4.7 +-- of the JVM spec: "String content is encoded in modified UTF-8. +-- Modified UTF-8 strings are encoded so that code point sequences +-- that contain only non-null ASCII characters can be represented +-- using only 1 byte per code point, but all code points in the +-- Unicode codespace can be represented." +-- +-- "There are two differences between this format and the "standard" +-- UTF-8 format. First, the null character (char)0 is encoded using +-- the 2-byte format rather than the 1-byte format, so that modified +-- UTF-8 strings never have embedded nulls. Second, only the 1-byte, +-- 2-byte, and 3-byte formats of standard UTF-8 are used. The Java +-- Virtual Machine does not recognize the four-byte format of standard +-- UTF-8; it uses its own two-times-three-byte format instead." +parseJavaString :: [Word8] -> Maybe String +parseJavaString [] = Just [] +parseJavaString (x : rest) + | (x .&. 0x80) == 0 = (:) (chr (fromIntegral x)) <$> parseJavaString rest +parseJavaString (x : y : rest) | (x .&. 0xE0) == 0xC0 && ((y .&. 0xC0) == 0x80) - = chr i : getJavaString rest + = (:) (chr i) <$> parseJavaString rest where i = (fromIntegral x .&. 0x1F) `shift` 6 + (fromIntegral y .&. 0x3F) -getJavaString (x : y : z : rest) +parseJavaString (x : y : z : rest) | (x .&. 0xF0) == 0xE0 && ((y .&. 0xC0) == 0x80) && ((z .&. 0xC0) == 0x80) - = chr i : getJavaString rest + = (:) (chr i) <$> parseJavaString rest where i = ((fromIntegral x .&. 0x0F) `shift` 12 + (fromIntegral y .&. 0x3F) `shift` 6 + (fromIntegral z .&. 0x3F)) -getJavaString _ = error "internal: unable to parse byte array for Java string" +parseJavaString _ = Nothing getConstantPoolInfo :: Get [ConstantPoolInfo] getConstantPoolInfo = do @@ -255,7 +297,9 @@ getConstantPoolInfo = do case tag of -- CONSTANT_Utf8 1 -> do bytes <- replicateN getWord8 =<< getWord16be - return [Utf8 $ getJavaString bytes] + case parseJavaString bytes of + Nothing -> failure "unable to parse byte array for Java string" + Just s -> return [Utf8 s] ---- CONSTANT_Integer 3 -> do val <- get return [ConstantInteger val] @@ -301,8 +345,7 @@ getConstantPoolInfo = do 18 -> do bootstrapMethodIndex <- getWord16be nameTypeIndex <- getWord16be return [InvokeDynamic bootstrapMethodIndex nameTypeIndex] - _ -> do position <- bytesRead - error ("Unexpected constant " ++ show tag ++ " at position " ++ show position) + _ -> do failure ("Unexpected constant " ++ show tag) type ConstantPoolIndex = Word16 type ConstantPool = Array ConstantPoolIndex ConstantPoolInfo @@ -318,84 +361,105 @@ getConstantPool = do info <- getConstantPoolInfo parseList (n - fromIntegral (length info)) (info ++ result) --- | Returns string at given index in constant pool or raises error --- | if constant pool index is not a Utf8 string. -poolUtf8 :: ConstantPool -> ConstantPoolIndex -> String +-- | Return the string at the given index in the constant pool, or +-- fail if the constant pool index is not a UTF-8 string. +poolUtf8 :: ConstantPool -> ConstantPoolIndex -> Get String poolUtf8 cp i = case cp ! i of - Utf8 s -> s - v -> error $ "Index " ++ show i ++ " has value " ++ show v ++ " when string expected." + Utf8 s -> pure s + v -> failure $ "Index " ++ show i ++ " has value " ++ show v ++ " when string expected." -- | Returns value at given index in constant pool or raises error --- | if constant pool index is not a value. -poolValue :: ConstantPool -> ConstantPoolIndex -> ConstantPoolValue +-- if constant pool index is not a value. +poolValue :: ConstantPool -> ConstantPoolIndex -> Get ConstantPoolValue poolValue cp i = case cp ! i of - ConstantClass j -> ClassRef (mkClassName (cp `poolUtf8` j)) - ConstantDouble v -> Double v - ConstantFloat v -> Float v - ConstantInteger v -> Integer v - ConstantLong v -> Long v - ConstantString j -> String (cp `poolUtf8` j) - v -> error ("Index " ++ show i ++ " has unexpected value " ++ show v - ++ " when a constant was expected.") - -poolClassType :: ConstantPool -> ConstantPoolIndex -> Type -poolClassType cp i - = case cp ! i of - ConstantClass j -> - let typeName = poolUtf8 cp j - in if head typeName == '[' - then fst (parseTypeDescriptor typeName) - else ClassType (mkClassName typeName) - _ -> error ("Index " ++ show i ++ " is not a class reference.") - -poolNameAndType :: ConstantPool -> ConstantPoolIndex -> (String, String) -poolNameAndType cp i - = case cp ! i of - NameAndType nameIndex typeIndex -> - (poolUtf8 cp nameIndex, poolUtf8 cp typeIndex) - _ -> error ("Index " ++ show i ++ " is not a name and type reference.") + ConstantClass j -> ClassRef . mkClassName <$> poolUtf8 cp j + ConstantDouble v -> pure $ Double v + ConstantFloat v -> pure $ Float v + ConstantInteger v -> pure $ Integer v + ConstantLong v -> pure $ Long v + ConstantString j -> String <$> poolUtf8 cp j + v -> failure ("Index " ++ show i ++ " has unexpected value " ++ show v + ++ " when a constant was expected.") + +parseType :: String -> Get Type +parseType s = + case parseTypeDescriptor s of + Just (tp, []) -> pure tp + _ -> failure ("Invalid type descriptor: " ++ show s) + +-- | For instructions that are described in the JVM spec like this: +-- "The run-time constant pool item at the index must be a symbolic +-- reference to a class, array, or interface type." +poolClassType :: ConstantPool -> ConstantPoolIndex -> Get Type +poolClassType cp i = + case cp ! i of + ConstantClass j -> + do typeName <- poolUtf8 cp j + if head typeName == '[' + then parseType typeName + else pure $ ClassType (mkClassName typeName) + _ -> + failure ("Index " ++ show i ++ " is not a class reference.") + +poolClassName :: ConstantPool -> ConstantPoolIndex -> Get ClassName +poolClassName cp i = + case cp ! i of + ConstantClass j -> + do typeName <- poolUtf8 cp j + when (head typeName == '[') $ + failure ("Index " ++ show i ++ " is an array type and not a class.") + pure $ mkClassName typeName + _ -> + failure ("Index " ++ show i ++ " is not a class reference.") + +poolNameAndType :: ConstantPool -> ConstantPoolIndex -> Get (String, String) +poolNameAndType cp i = + case cp ! i of + NameAndType nameIndex typeIndex -> + (,) <$> poolUtf8 cp nameIndex <*> poolUtf8 cp typeIndex + _ -> failure ("Index " ++ show i ++ " is not a name and type reference.") -- | Returns tuple containing field class, name, and type at given index. -poolFieldRef :: ConstantPool -> ConstantPoolIndex -> FieldId -poolFieldRef cp i - = case cp ! i of - FieldRef classIndex ntIndex -> - let (name, fldDescriptor) = poolNameAndType cp ntIndex - (fldType, []) = parseTypeDescriptor fldDescriptor - ClassType cName = poolClassType cp classIndex - in FieldId cName name fldType - _ -> error ("Index " ++ show i ++ " is not a field reference.") - -poolInterfaceMethodRef :: ConstantPool -> ConstantPoolIndex -> (Type, MethodKey) -poolInterfaceMethodRef cp i - = case cp ! i of - InterfaceMethodRef classIndex ntIndex -> - poolTypeAndMethodKey cp classIndex ntIndex - _ -> error ("Index " ++ show i ++ " is not an interface method reference.") - -poolMethodRef :: ConstantPool -> ConstantPoolIndex -> (Type, MethodKey) -poolMethodRef cp i - = case cp ! i of - MethodRef classIndex ntIndex -> - poolTypeAndMethodKey cp classIndex ntIndex - _ -> error ("Index " ++ show i ++ " is not a method reference.") - -poolMethodOrInterfaceRef :: ConstantPool -> ConstantPoolIndex -> (Type, MethodKey) -poolMethodOrInterfaceRef cp i - = case cp ! i of - MethodRef classIndex ntIndex -> - poolTypeAndMethodKey cp classIndex ntIndex - InterfaceMethodRef classIndex ntIndex -> - poolTypeAndMethodKey cp classIndex ntIndex - _ -> error ("Index " ++ show i ++ " is not a method or interface method reference.") - -poolTypeAndMethodKey :: ConstantPool -> ConstantPoolIndex -> ConstantPoolIndex -> (Type, MethodKey) +poolFieldRef :: ConstantPool -> ConstantPoolIndex -> Get FieldId +poolFieldRef cp i = + case cp ! i of + FieldRef classIndex ntIndex -> + do (name, descriptor) <- poolNameAndType cp ntIndex + fldType <- parseType descriptor + cName <- poolClassName cp classIndex + pure $ FieldId cName name fldType + _ -> failure ("Index " ++ show i ++ " is not a field reference.") + +poolInterfaceMethodRef :: ConstantPool -> ConstantPoolIndex -> Get (Type, MethodKey) +poolInterfaceMethodRef cp i = + case cp ! i of + InterfaceMethodRef classIndex ntIndex -> + poolTypeAndMethodKey cp classIndex ntIndex + _ -> failure ("Index " ++ show i ++ " is not an interface method reference.") + +poolMethodRef :: ConstantPool -> ConstantPoolIndex -> Get (Type, MethodKey) +poolMethodRef cp i = + case cp ! i of + MethodRef classIndex ntIndex -> + poolTypeAndMethodKey cp classIndex ntIndex + _ -> failure ("Index " ++ show i ++ " is not a method reference.") + +poolMethodOrInterfaceRef :: ConstantPool -> ConstantPoolIndex -> Get (Type, MethodKey) +poolMethodOrInterfaceRef cp i = + case cp ! i of + MethodRef classIndex ntIndex -> + poolTypeAndMethodKey cp classIndex ntIndex + InterfaceMethodRef classIndex ntIndex -> + poolTypeAndMethodKey cp classIndex ntIndex + _ -> failure ("Index " ++ show i ++ " is not a method or interface method reference.") + +poolTypeAndMethodKey :: ConstantPool -> ConstantPoolIndex -> ConstantPoolIndex -> Get (Type, MethodKey) poolTypeAndMethodKey cp classIndex ntIndex = - let (name, fieldDescriptor) = poolNameAndType cp ntIndex - classType = poolClassType cp classIndex - in (classType, makeMethodKey name fieldDescriptor) + do (name, fieldDescriptor) <- poolNameAndType cp ntIndex + classType <- poolClassType cp classIndex + pure (classType, makeMethodKey name fieldDescriptor) _uncurry3 :: (a -> b -> c -> d) -> (a,b,c) -> d _uncurry3 fn (a,b,c) = fn a b c @@ -422,11 +486,11 @@ getInstruction cp address = do 0x0D -> return $ Ldc $ Float 2.0 0x0E -> return $ Ldc $ Double 0.0 0x0F -> return $ Ldc $ Double 1.0 - 0x10 -> liftM (Ldc . Integer . fromIntegral) (get :: Get Int8) - 0x11 -> liftM (Ldc . Integer . fromIntegral) (get :: Get Int16) - 0x12 -> liftM (Ldc . poolValue cp . fromIntegral) getWord8 - 0x13 -> liftM (Ldc . poolValue cp) getWord16be - 0x14 -> liftM (Ldc . poolValue cp) getWord16be + 0x10 -> liftM (Ldc . Integer . fromIntegral) getInt8 + 0x11 -> liftM (Ldc . Integer . fromIntegral) getInt16be + 0x12 -> liftM Ldc $ poolValue cp =<< liftM fromIntegral getWord8 + 0x13 -> liftM Ldc $ poolValue cp =<< getWord16be + 0x14 -> liftM Ldc $ poolValue cp =<< getWord16be 0x15 -> liftM (Iload . fromIntegral) getWord8 0x16 -> liftM (Lload . fromIntegral) getWord8 0x17 -> liftM (Fload . fromIntegral) getWord8 @@ -540,7 +604,7 @@ getInstruction cp address = do 0x83 -> return Lxor 0x84 -> do index <- getWord8 - constant <- get :: Get Int8 + constant <- getInt8 return (Iinc (fromIntegral index) (fromIntegral constant)) 0x85 -> return I2l 0x86 -> return I2f @@ -562,41 +626,41 @@ getInstruction cp address = do 0x96 -> return Fcmpg 0x97 -> return Dcmpl 0x98 -> return Dcmpg - 0x99 -> return . Ifeq . (address +) . fromIntegral =<< (get :: Get Int16) - 0x9A -> return . Ifne . (address +) . fromIntegral =<< (get :: Get Int16) - 0x9B -> return . Iflt . (address +) . fromIntegral =<< (get :: Get Int16) - 0x9C -> return . Ifge . (address +) . fromIntegral =<< (get :: Get Int16) - 0x9D -> return . Ifgt . (address +) . fromIntegral =<< (get :: Get Int16) - 0x9E -> return . Ifle . (address +) . fromIntegral =<< (get :: Get Int16) - 0x9F -> return . If_icmpeq . (address +) . fromIntegral =<< (get :: Get Int16) - 0xA0 -> return . If_icmpne . (address +) . fromIntegral =<< (get :: Get Int16) - 0xA1 -> return . If_icmplt . (address +) . fromIntegral =<< (get :: Get Int16) - 0xA2 -> return . If_icmpge . (address +) . fromIntegral =<< (get :: Get Int16) - 0xA3 -> return . If_icmpgt . (address +) . fromIntegral =<< (get :: Get Int16) - 0xA4 -> return . If_icmple . (address +) . fromIntegral =<< (get :: Get Int16) - 0xA5 -> return . If_acmpeq . (address +) . fromIntegral =<< (get :: Get Int16) - 0xA6 -> return . If_acmpne . (address +) . fromIntegral =<< (get :: Get Int16) - 0xA7 -> return . Goto . (address +) . fromIntegral =<< (get :: Get Int16) - 0xA8 -> return . Jsr . (address +) . fromIntegral =<< (get :: Get Int16) + 0x99 -> return . Ifeq . (address +) . fromIntegral =<< getInt16be + 0x9A -> return . Ifne . (address +) . fromIntegral =<< getInt16be + 0x9B -> return . Iflt . (address +) . fromIntegral =<< getInt16be + 0x9C -> return . Ifge . (address +) . fromIntegral =<< getInt16be + 0x9D -> return . Ifgt . (address +) . fromIntegral =<< getInt16be + 0x9E -> return . Ifle . (address +) . fromIntegral =<< getInt16be + 0x9F -> return . If_icmpeq . (address +) . fromIntegral =<< getInt16be + 0xA0 -> return . If_icmpne . (address +) . fromIntegral =<< getInt16be + 0xA1 -> return . If_icmplt . (address +) . fromIntegral =<< getInt16be + 0xA2 -> return . If_icmpge . (address +) . fromIntegral =<< getInt16be + 0xA3 -> return . If_icmpgt . (address +) . fromIntegral =<< getInt16be + 0xA4 -> return . If_icmple . (address +) . fromIntegral =<< getInt16be + 0xA5 -> return . If_acmpeq . (address +) . fromIntegral =<< getInt16be + 0xA6 -> return . If_acmpne . (address +) . fromIntegral =<< getInt16be + 0xA7 -> return . Goto . (address +) . fromIntegral =<< getInt16be + 0xA8 -> return . Jsr . (address +) . fromIntegral =<< getInt16be 0xA9 -> liftM (Ret . fromIntegral) getWord8 0xAA -> do read <- bytesRead skip $ fromIntegral $ (4 - read `mod` 4) `mod` 4 - defaultBranch <- return . (address +) . fromIntegral =<< (get :: Get Int32) - low <- get :: Get Int32 - high <- get :: Get Int32 + defaultBranch <- return . (address +) . fromIntegral =<< getInt32be + low <- getInt32be + high <- getInt32be offsets <- replicateN - (return . (address +) . fromIntegral =<< (get :: Get Int32)) + (return . (address +) . fromIntegral =<< getInt32be) (high - low + 1) return $ Tableswitch defaultBranch low high offsets 0xAB -> do read <- bytesRead skip (fromIntegral ((4 - read `mod` 4) `mod` 4)) - defaultBranch <- get :: Get Int32 - count <- get :: Get Int32 + defaultBranch <- getInt32be + count <- getInt32be pairs <- replicateM (fromIntegral count) $ do - v <- get :: Get Int32 - o <- get :: Get Int32 + v <- getInt32be + o <- getInt32be return (v, ((address +) . fromIntegral) o) return $ Lookupswitch (address + fromIntegral defaultBranch) pairs 0xAC -> return Ireturn @@ -605,52 +669,55 @@ getInstruction cp address = do 0xAF -> return Dreturn 0xB0 -> return Areturn 0xB1 -> return Return - 0xB2 -> return . Getstatic . poolFieldRef cp =<< getWord16be - 0xB3 -> return . Putstatic . poolFieldRef cp =<< getWord16be - 0xB4 -> return . Getfield . poolFieldRef cp =<< getWord16be - 0xB5 -> return . Putfield . poolFieldRef cp =<< getWord16be + 0xB2 -> Getstatic <$> (poolFieldRef cp =<< getWord16be) + 0xB3 -> Putstatic <$> (poolFieldRef cp =<< getWord16be) + 0xB4 -> Getfield <$> (poolFieldRef cp =<< getWord16be) + 0xB5 -> Putfield <$> (poolFieldRef cp =<< getWord16be) 0xB6 -> do index <- getWord16be - let (classType, key) = poolMethodRef cp index + (classType, key) <- poolMethodRef cp index return $ Invokevirtual classType key 0xB7 -> do index <- getWord16be - let (classType, key) = poolMethodOrInterfaceRef cp index + (classType, key) <- poolMethodOrInterfaceRef cp index return $ Invokespecial classType key 0xB8 -> do index <- getWord16be - let (ClassType cName, key) = poolMethodOrInterfaceRef cp index - in return $ Invokestatic cName key + (classType, key) <- poolMethodOrInterfaceRef cp index + cName <- + case classType of + ClassType cName -> pure cName + _ -> failure ("invokestatic: expected class type, found " ++ show classType) + pure $ Invokestatic cName key 0xB9 -> do index <- getWord16be _ <- getWord8 _ <- getWord8 - let (ClassType cName, key) = poolInterfaceMethodRef cp index - in return $ Invokeinterface cName key + (classType, key) <- poolInterfaceMethodRef cp index + cName <- + case classType of + ClassType cName -> pure cName + _ -> failure ("invokeinterface: expected class type, found " ++ show classType) + pure $ Invokeinterface cName key 0xBA -> do index <- getWord16be _ <- getWord8 _ <- getWord8 return $ Invokedynamic index - 0xBB -> do - index <- getWord16be - case (poolClassType cp index) of - ClassType name -> return (New name) - _ -> error "internal: unexpected pool class type" - 0xBC -> do - typeCode <- getWord8 - (return . Newarray . ArrayType) - (case typeCode of - 4 -> BooleanType - 5 -> CharType - 6 -> FloatType - 7 -> DoubleType - 8 -> ByteType - 9 -> ShortType - 10 -> IntType - 11 -> LongType - _ -> error "internal: invalid type code encountered" - ) - 0xBD -> return . Newarray . ArrayType . poolClassType cp =<< get + 0xBB -> New <$> (poolClassName cp =<< getWord16be) + 0xBC -> do typeCode <- getWord8 + elementType <- + case typeCode of + 4 -> pure BooleanType + 5 -> pure CharType + 6 -> pure FloatType + 7 -> pure DoubleType + 8 -> pure ByteType + 9 -> pure ShortType + 10 -> pure IntType + 11 -> pure LongType + _ -> failure "internal: invalid type code encountered" + pure $ Newarray (ArrayType elementType) + 0xBD -> Newarray . ArrayType <$> (poolClassType cp =<< getWord16be) 0xBE -> return Arraylength 0xBF -> return Athrow - 0xC0 -> return . Checkcast . poolClassType cp =<< get - 0xC1 -> return . Instanceof . poolClassType cp =<< get + 0xC0 -> Checkcast <$> (poolClassType cp =<< getWord16be) + 0xC1 -> Instanceof <$> (poolClassType cp =<< getWord16be) 0xC2 -> return Monitorenter 0xC3 -> return Monitorexit -- Wide instruction @@ -667,22 +734,19 @@ getInstruction cp address = do 0x38 -> liftM Fstore getWord16be 0x39 -> liftM Dstore getWord16be 0x3A -> liftM Astore getWord16be - 0x84 -> liftM2 Iinc getWord16be (get :: Get Int16) + 0x84 -> liftM2 Iinc getWord16be getInt16be 0xA9 -> liftM Ret getWord16be _ -> do position <- bytesRead - error ("Unexpected wide op " ++ (show op) ++ " at position " ++ show (position - 2)) - 0xC5 -> do - classIndex <- getWord16be - dimensions <- getWord8 - return (Multianewarray (poolClassType cp classIndex) dimensions) - 0xC6 -> return . Ifnull . (address +) . fromIntegral =<< (get :: Get Int16) - 0xC7 -> return . Ifnonnull . (address +) . fromIntegral =<< (get :: Get Int16) - 0xC8 -> return . Goto . (address +) . fromIntegral =<< (get :: Get Int32) - 0xC9 -> return . Jsr . (address +) . fromIntegral =<< (get :: Get Int32) + failure ("Unexpected wide op " ++ (show op) ++ " at position " ++ show (position - 2)) + 0xC5 -> Multianewarray <$> (poolClassType cp =<< getWord16be) <*> getWord8 + 0xC6 -> return . Ifnull . (address +) . fromIntegral =<< getInt16be + 0xC7 -> return . Ifnonnull . (address +) . fromIntegral =<< getInt16be + 0xC8 -> return . Goto . (address +) . fromIntegral =<< getInt32be + 0xC9 -> return . Jsr . (address +) . fromIntegral =<< getInt32be _ -> do position <- bytesRead - error ("Unexpected op " ++ (show op) ++ " at position " ++ show (position - 1)) + failure ("Unexpected op " ++ (show op) ++ " at position " ++ show (position - 1)) ---------------------------------------------------------------------- -- Attributes @@ -708,14 +772,14 @@ splitAttributes cp names = do impl n values rest = do nameIndex <- getWord16be len <- getWord32be - let name = (poolUtf8 cp nameIndex) - in case elemIndex name names of - Just i -> do - bytes <- getLazyByteString (fromIntegral len) - impl (n-1) (appendAt values i bytes) rest - Nothing -> do - bytes <- getByteString (fromIntegral len) - impl (n-1) values (Attribute name bytes : rest) + name <- poolUtf8 cp nameIndex + case elemIndex name names of + Just i -> + do bytes <- getLazyByteString (fromIntegral len) + impl (n - 1) (appendAt values i bytes) rest + Nothing -> + do bytes <- getByteString (fromIntegral len) + impl (n - 1) values (Attribute name bytes : rest) ---------------------------------------------------------------------- -- Field declarations @@ -784,19 +848,31 @@ data Field = Field { getField :: ConstantPool -> Get Field getField cp = do accessFlags <- getWord16be - name <- return . poolUtf8 cp =<< getWord16be - fldType <- return . fst . parseTypeDescriptor . poolUtf8 cp =<< getWord16be + name <- poolUtf8 cp =<< getWord16be + fldType <- parseType =<< poolUtf8 cp =<< getWord16be ([constantValue, synthetic, deprecated, signature], userAttrs) <- splitAttributes cp ["ConstantValue", "Synthetic", "Deprecated", "Signature"] + constantVal <- + case constantValue of + [bytes] -> Just <$> (poolValue cp =<< subParser getWord16be bytes) + [] -> pure Nothing + _ -> failure "internal: unexpected constant value form" + sig <- + case signature of + [bytes] -> Just <$> (poolUtf8 cp =<< subParser getWord16be bytes) + [] -> pure Nothing + _ -> failure "internal: unexpected signature form" + visibility <- + case accessFlags .&. 0x7 of + 0x0 -> pure Default + 0x1 -> pure Public + 0x2 -> pure Private + 0x4 -> pure Protected + flags -> failure $ "Unexpected flags " ++ show flags return $ Field name fldType -- Visibility - (case accessFlags .&. 0x7 of - 0x0 -> Default - 0x1 -> Public - 0x2 -> Private - 0x4 -> Protected - flags -> error $ "Unexpected flags " ++ show flags) + visibility -- Static ((accessFlags .&. 0x0008) /= 0) -- Final @@ -806,11 +882,7 @@ getField cp = do -- Transient ((accessFlags .&. 0x0080) /= 0) -- Constant Value - (case constantValue of - [bytes] -> Just $ poolValue cp $ runGet getWord16be bytes - [] -> Nothing - _ -> error "internal: unexpected constant value form" - ) + constantVal -- Check for synthetic bit in flags and buffer ((accessFlags .&. 0x1000) /= 0 || (not (null synthetic))) -- Deprecated flag @@ -818,12 +890,7 @@ getField cp = do -- Check for enum bit in flags ((accessFlags .&. 0x4000) /= 0) -- Signature - (case signature of - [bytes] -> - Just $ poolUtf8 cp $ runGet getWord16be bytes - [] -> Nothing - _ -> error "internal: unexpected signature form" - ) + sig userAttrs ---------------------------------------------------------------------- @@ -834,13 +901,15 @@ getExceptionTableEntry cp = do startPc' <- getWord16be endPc' <- getWord16be handlerPc' <- getWord16be - catchType' <- getWord16be + catchIndex <- getWord16be + catchType' <- + if catchIndex == 0 + then pure Nothing + else Just <$> poolClassType cp catchIndex return (ExceptionTableEntry startPc' endPc' handlerPc' - (if catchType' == 0 - then Nothing - else Just (poolClassType cp catchType'))) + catchType') -- Run Get Monad until end of string is reached and return list of results. getInstructions :: ConstantPool -> PC -> Get InstructionStream @@ -881,13 +950,12 @@ getLineNumberTableEntries = do lineNumber <- getWord16be return (startPc', lineNumber)) - -parseLineNumberTable :: [L.ByteString] -> LineNumberTable +parseLineNumberTable :: [L.ByteString] -> Get LineNumberTable parseLineNumberTable buffers = - let l = concatMap (runGet getLineNumberTableEntries) buffers - in LNT { pcLineMap = Map.fromList l - , linePCMap = Map.fromListWith min [ (ln,pc) | (pc,ln) <- l ] - } + do l <- concat <$> traverse (subParser getLineNumberTableEntries) buffers + pure LNT { pcLineMap = Map.fromList l + , linePCMap = Map.fromListWith min [ (ln,pc) | (pc,ln) <- l ] + } ---------------------------------------------------------------------- -- LocalVariableTableEntry @@ -911,19 +979,14 @@ getLocalVariableTableEntries cp = do replicateM (fromIntegral tableLength) (do startPc' <- getWord16be len <- getWord16be - nameIndex <- getWord16be - descriptorIndex <- getWord16be + name <- getWord16be >>= poolUtf8 cp + ty <- getWord16be >>= poolUtf8 cp >>= parseType index <- getWord16be - return $ LocalVariableTableEntry - startPc' - len - (poolUtf8 cp nameIndex) - (fst $ parseTypeDescriptor $ poolUtf8 cp descriptorIndex) - index) - -parseLocalVariableTable :: ConstantPool -> [L.ByteString] -> [LocalVariableTableEntry] + pure $ LocalVariableTableEntry startPc' len name ty index) + +parseLocalVariableTable :: ConstantPool -> [L.ByteString] -> Get [LocalVariableTableEntry] parseLocalVariableTable cp buffers = - (concat $ map (runGet $ getLocalVariableTableEntries cp) buffers) + concat <$> traverse (subParser (getLocalVariableTableEntries cp)) buffers ---------------------------------------------------------------------- -- Method body @@ -949,12 +1012,14 @@ getCode cp = do exceptionTable <- getWord16be >>= replicateN (getExceptionTableEntry cp) ([lineNumberTables, localVariableTables], userAttrs) <- splitAttributes cp ["LineNumberTable", "LocalVariableTable"] + lnt <- parseLineNumberTable lineNumberTables + lvt <- parseLocalVariableTable cp localVariableTables return $ Code maxStack maxLocals (buildCFG exceptionTable instructions) exceptionTable - (parseLineNumberTable lineNumberTables) - (parseLocalVariableTable cp localVariableTables) + lnt + lvt userAttrs ---------------------------------------------------------------------- @@ -1027,45 +1092,49 @@ instance Ord Method where getExceptions :: ConstantPool -> Get [Type] getExceptions cp = do exceptionCount <- getWord16be - replicateN (getWord16be >>= return . poolClassType cp) exceptionCount + replicateN (getWord16be >>= poolClassType cp) exceptionCount getMethod :: ConstantPool -> Get Method getMethod cp = do accessFlags <- getWord16be - name <- getWord16be >>= return . (poolUtf8 cp) - (returnType, parameterTypes) <- getWord16be >>= return . parseMethodDescriptor . (poolUtf8 cp) + name <- getWord16be >>= poolUtf8 cp + descriptor <- getWord16be >>= poolUtf8 cp + (returnType, parameterTypes) <- + maybe (failure "Invalid method descriptor") pure $ parseMethodDescriptor descriptor ([codeVal, exceptionsVal, syntheticVal, deprecatedVal], userAttrs) <- splitAttributes cp ["Code", "Exceptions", "Synthetic", "Deprecated"] + visibility <- + case accessFlags .&. 0x7 of + 0x0 -> pure Default + 0x1 -> pure Public + 0x2 -> pure Private + 0x4 -> pure Protected + flags -> failure $ "Unexpected flags " ++ show flags let isStatic' = (accessFlags .&. 0x008) /= 0 isFinal = (accessFlags .&. 0x010) /= 0 isSynchronized' = (accessFlags .&. 0x020) /= 0 isAbstract = (accessFlags .&. 0x400) /= 0 isStrictFp' = (accessFlags .&. 0x800) /= 0 - in return $ + body <- + if ((accessFlags .&. 0x100) /= 0) then pure NativeMethod else + if isAbstract then pure AbstractMethod else + case codeVal of + [bytes] -> subParser (getCode cp) bytes + _ -> failure "Could not find code attribute" + exceptions <- + case exceptionsVal of + [bytes] -> Just <$> subParser (getExceptions cp) bytes + [] -> pure Nothing + _ -> failure "internal: unexpected expectionsVal form" + return $ Method (MethodKey name parameterTypes returnType) - -- Visibility - (case accessFlags .&. 0x7 of - 0x0 -> Default - 0x1 -> Public - 0x2 -> Private - 0x4 -> Protected - flags -> error $ "Unexpected flags " ++ show flags) + visibility isStatic' isFinal isSynchronized' isStrictFp' - (if ((accessFlags .&. 0x100) /= 0) - then NativeMethod - else if isAbstract - then AbstractMethod - else case codeVal of - [bytes] -> runGet (getCode cp) bytes - _ -> error "Could not find code attribute") - (case exceptionsVal of - [bytes] -> Just (runGet (getExceptions cp) bytes) - [] -> Nothing - _ -> error "internal: unexpected expectionsVal form" - ) + body + exceptions (not $ null syntheticVal) (not $ null deprecatedVal) userAttrs @@ -1284,18 +1353,25 @@ getClass :: Get Class getClass = do magic <- getWord32be (if magic /= 0xCAFEBABE - then error "Unexpected magic value" + then failure "Unexpected magic value" else return ()) minorVersion' <- getWord16be majorVersion' <- getWord16be cp <- getConstantPool accessFlags <- getWord16be - thisClass <- getReferenceName cp + thisClass <- getWord16be >>= poolClassName cp superClassIndex <- getWord16be - interfaces <- getWord16be >>= replicateN (getReferenceName cp) + superClass' <- if superClassIndex == 0 then pure Nothing else + Just <$> poolClassName cp superClassIndex + interfaces <- getWord16be >>= replicateN (getWord16be >>= poolClassName cp) fields <- getWord16be >>= replicateN (getField cp) methods <- getWord16be >>= replicateN (getMethod cp) ([sourceFile], userAttrs) <- splitAttributes cp ["SourceFile"] + sourceFile' <- + case sourceFile of + [bytes] -> Just <$> (poolUtf8 cp =<< subParser getWord16be bytes) + [] -> pure Nothing + _ -> failure "internal: unexpected source file form" return $ MkClass majorVersion' minorVersion' cp @@ -1305,28 +1381,12 @@ getClass = do ((accessFlags .&. 0x200) /= 0) ((accessFlags .&. 0x400) /= 0) thisClass - (if superClassIndex == 0 - then Nothing - else - case poolClassType cp superClassIndex of - ClassType name -> (Just name) - classType -> error ("Unexpected class type " ++ show classType)) + superClass' interfaces fields (Map.fromList (map (\m -> (methodKey m, m)) methods)) - -- Source file - (case sourceFile of - [bytes] -> - Just $ poolUtf8 cp $ runGet getWord16be bytes - [] -> Nothing - _ -> error "internal: unexpected source file form" - ) + sourceFile' userAttrs - where getReferenceName cp = do - index <- getWord16be - case poolClassType cp index of - ClassType name -> return name - tp -> error ("Unexpected class type " ++ show tp) -- | Returns method with given key in class or 'Nothing' if no method with that -- key is found. @@ -1338,8 +1398,9 @@ loadClass :: FilePath -> IO Class loadClass path = do handle <- openBinaryFile path ReadMode contents <- L.hGetContents handle - let result = runGet getClass contents - in result `seq` (hClose handle >> return result) + case runGetOrFail getClass contents of + Left (_, pos, msg) -> fail $ "loadClass: parse failure at offset " ++ show pos ++ ": " ++ msg + Right (_, _, result) -> result `seq` (hClose handle >> pure result) getElemTy :: Type -> Type getElemTy (ArrayType t) = aux t