diff --git a/.ghci b/.ghci index de6f30c..756ad40 100644 --- a/.ghci +++ b/.ghci @@ -7,3 +7,5 @@ :def! unimath (\_ -> return ":r\n :set args --no-terms -r -v agda2train:10 -ojson/ -i $HOME/agda-unimath/src $HOME/agda-unimath/src/everything.lagda.md\n main") :def! typetopo (\_ -> return ":r\n :set args --no-terms -r -v agda2train:10 -ojson/ -i $HOME/TypeTopology/source $HOME/TypeTopology/source/index.lagda\n main") :def! prelude (\_ -> return ":r\n :set args -r -v agda2train:10 -ojson/ -i $HOME/git/formal-prelude/ $HOME/git/formal-prelude/Prelude/Main.agda\n main") +:def! testDB (\x -> return $ ":r\n :set args add dist/db.json test/golden/" <> x <> ".json\n main") +:def! testQueryDB (\x -> return $ ":r\n :set args query dist/db.json " <> x <> ".json\n main") diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9387c95..6a45d3e 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -5,7 +5,7 @@ on: # schedule: [{cron: '0 0 * * *'}] pull_request: push: - branches: [master] + branches: [master, db-example] jobs: cabal: @@ -48,6 +48,7 @@ jobs: # fingerprint doesn't match, so we don't need to start from scratch every # time a dependency changes. - uses: actions/cache@v3 + if: github.ref == 'refs/heads/master' name: Cache ~/.cabal/store with: path: ${{ steps.setup-haskell-cabal.outputs.cabal-store }} @@ -55,7 +56,9 @@ jobs: restore-keys: ${{ runner.os }}-${{ matrix.ghc }}- - name: Test - run: make test + run: | + make test + make testDB # - name: Free disk space # if: ${{ matrix.stdlib }} diff --git a/Makefile b/Makefile index f419289..935f554 100644 --- a/Makefile +++ b/Makefile @@ -5,6 +5,9 @@ default: repl repl: cabal repl agda2train # e.g. `:set args -r -o json -itest test/First.agda ... main ... :r ... main` +replDB: + cabal repl agda2train-db + build: cabal build @@ -27,6 +30,13 @@ cleanTest: make -C test clean make -C test cleanGolden +# DB example + +testDB: + cabal run agda2train-db -- add data/db.json test/golden/Test.PiVsFun.json + cabal run agda2train-db -- query data/db.json data/query.json \ + | tail -n1 | cmp data/expected_response.txt + # Extracting training data from whole libraries STDLIB?=$(HOME)/git/agda-stdlib diff --git a/agda2train.cabal b/agda2train.cabal index c4bb3d7..2268e16 100644 --- a/agda2train.cabal +++ b/agda2train.cabal @@ -48,6 +48,8 @@ common globalOptions , mtl >=2.2.1 && <2.4 , async >=2.2 && <2.3 , file-embed == 0.0.15.0 + , aeson-pretty >= 0.8.9 && < 0.9 + , bytestring >=0.10.8.1 && <0.13 library agda2train-lib import: globalOptions @@ -65,8 +67,17 @@ executable agda2train build-depends: agda2train-lib , deepseq >=1.4.2.0 && <1.6 - , bytestring >=0.10.8.1 && <0.13 , directory >=1.2.6.2 && <1.4 , filepath >=1.4.1.0 && <1.5 , unordered-containers >=0.2.9.0 && <0.3 - , aeson-pretty >= 0.8.9 && < 0.9 + +executable agda2train-db + import: globalOptions + main-is: DB.hs + ghc-options: + -threaded -rtsopts -with-rtsopts=-N + -Wno-missing-home-modules + build-depends: + agda2train-lib + , unordered-containers >=0.2.9.0 && <0.3 + , hashable diff --git a/data/db.json b/data/db.json new file mode 100644 index 0000000..67261e2 --- /dev/null +++ b/data/db.json @@ -0,0 +1,11 @@ +{ + "-5242397938801254731": [ + "Test.PiVsFun.refl<10>" + ], + "4884008893817283726": [ + "Test.PiVsFun.refl<10>" + ], + "7116324300784873135": [ + "Test.PiVsFun.refl<10>" + ] +} \ No newline at end of file diff --git a/data/expected_response.txt b/data/expected_response.txt new file mode 100644 index 0000000..4f6bf96 --- /dev/null +++ b/data/expected_response.txt @@ -0,0 +1 @@ +["Test.PiVsFun.refl<10>"] diff --git a/data/query.json b/data/query.json new file mode 100644 index 0000000..918b11c --- /dev/null +++ b/data/query.json @@ -0,0 +1,72 @@ +{ + "scope": [ + { + "name": "Test.PiVsFun.refl<10>", + "type": { + "tag": "Pi", + "name": "x", + "domain": { + "tag": "Sort", + "sort": "Set" + }, + "codomain": { + "tag": "Pi", + "name": "y", + "domain": { + "tag": "Sort", + "sort": "Set" + }, + "codomain": { + "tag": "Application", + "head": { + "tag": "ScopeReference", + "name": "Test.PiVsFun._≡_<4>" + }, + "arguments": [ + { + "tag": "DeBruijn", + "index": 1 + }, + { + "tag": "DeBruijn", + "index": 0 + } + ] + } + } + }, + "definition": { + "tag": "Postulate" + } + } + ], + "context": [ + { + "tag": "Sort", + "name": "x", + "sort": "Set" + }, + { + "tag": "Sort", + "name": "x", + "sort": "Set" + } + ], + "goal": { + "tag": "Application", + "head": { + "tag": "ScopeReference", + "name": "Test.PiVsFun._≡_<4>" + }, + "arguments": [ + { + "tag": "DeBruijn", + "index": 1 + }, + { + "tag": "DeBruijn", + "index": 1 + } + ] + } +} diff --git a/src/DB.hs b/src/DB.hs new file mode 100644 index 0000000..d527793 --- /dev/null +++ b/src/DB.hs @@ -0,0 +1,110 @@ +{-# LANGUAGE TypeApplications #-} +-- | Store the JSON information in an SQL database for quick lookups. +-- +-- NB: "training" this model by feeding samples merely consists of adding DB rows +module Main where + +import GHC.Generics + +import System.Environment ( getArgs ) +import qualified Data.HashMap.Strict as M +import Data.Maybe ( fromMaybe ) +import Control.Monad ( forM_ ) +import Control.Arrow ( first ) + +import Data.Aeson + ( ToJSON(..), genericToJSON + , FromJSON(..), genericParseJSON + , eitherDecodeFileStrict' ) + +import Data.Hashable ( Hashable, hash ) + +-- import Database.Persist.TH + +import Agda.Utils.Either ( caseEitherM ) + +import ToTrain ( names ) +import Output hiding ( ScopeEntry, ScopeEntry' ) +import qualified Output as O + +data ScopeEntry' = ScopeEntry + { _type :: Type + , definition :: Definition + } deriving (Generic, Show, Eq, Hashable) +instance ToJSON ScopeEntry' where toJSON = genericToJSON jsonOpts +instance FromJSON ScopeEntry' where parseJSON = genericParseJSON jsonOpts +type ScopeEntry = Named ScopeEntry' +data Context = Context + { scope :: [ScopeEntry] + , context :: [Type] + , goal :: Type + } deriving (Generic, Show, Eq, Hashable) +instance ToJSON Context where toJSON = genericToJSON jsonOpts +instance FromJSON Context where parseJSON = genericParseJSON jsonOpts +type Premises = [String] + +type Key' = Context +type Key = Int -- hash of `Context` +type Value = Premises +type Database = M.HashMap Key Value + +instance Hashable a => Hashable (Named a) +instance Hashable a => Hashable (Pretty a) +instance Hashable a => Hashable (Reduced a) +instance Hashable O.ScopeEntry'; instance Hashable Sample +instance Hashable Definition; instance Hashable Clause; instance Hashable Term + +fileSamples :: FileData -> [(Key', Value)] +fileSamples (Named{item = TrainData{..}}) = + concatMap (map go . fromMaybe [] . holes . item) scopeLocal + where + usedIn :: Name -> Term -> Bool + usedIn n = \case + Pi _ (n' :~ ty) t -> n `usedIn` ty || ((n' /= n) && n `usedIn` t) + Lam (n' :~ t) -> (n' /= n) && n `usedIn` t + App hd ts -> case hd of {Ref n' -> n' == n; DB _ -> False} + || any (n `usedIn`) ts + _ -> False + + go :: Sample -> (Key', Value) + go Sample{..} = Context + { scope = fmap bareScope + <$> scopeGlobal + <> onlyRelevant scopeLocal + <> onlyRelevant (fromMaybe [] scopePrivate) + , context = thing . item <$> thing (ctx) + , goal = original (thing goal) + } .-> premises + where (.->) = (,) + onlyRelevant = filter $ (`usedIn` original (thing term)) . name + bareScope :: O.ScopeEntry' -> ScopeEntry' + bareScope O.ScopeEntry{..} = ScopeEntry {_type = original (thing _type), definition = thing definition} + +main :: IO () +main = getArgs >>= \case + ("add" : dbJson : jsonFns) -> forM_ jsonFns $ \jsonFn -> do + caseEitherM (eitherDecodeFileStrict' jsonFn) fail $ \fd -> do + -- putStrLn "** File Data" >> print fd + let samples = fileSamples fd + putStrLn "** Samples" -- >> print samples + forM_ samples $ \(k, _) -> do + putStrLn "------" >> print k + putStrLn "# " >> print (hash k) + let db = M.fromList (first hash <$> samples) + putStrLn "** Database" >> print db + putStrLn "serializing database into a .json file" + encodeFile dbJson db + ("query" : dbJson : ctxJson : []) -> + caseEitherM (eitherDecodeFileStrict' @Database dbJson) fail $ \db -> do + putStrLn "** Database" >> print db + caseEitherM (eitherDecodeFileStrict' @Context ctxJson) fail $ \ctx -> do + putStrLn "** Context" >> print ctx + let hctx = hash ctx + putStrLn "# " >> print hctx + case M.lookup hctx db of + Just premises -> putStrLn "** Premises" >> print premises + Nothing -> putStrLn "NOT TRAINED ON THIS CONTEXT!" + args -> fail + $ "Usage: agda2train-db add * \n" + <> " or agda2train-db query \n\n" + <> "User entered: agda2train-db " <> unwords args diff --git a/src/Main.hs b/src/Main.hs index 6c51b4a..bc72aa3 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -13,9 +13,6 @@ import qualified Data.List as L import qualified Data.Set as S import qualified Data.HashMap.Strict as HM import qualified Data.ByteString.Lazy as BL -import Data.Aeson ( ToJSON ) -import Data.Aeson.Encode.Pretty - ( encodePretty', Config(..), defConfig, Indent(..), keyOrder ) import Control.Monad import Control.Monad.IO.Class ( liftIO ) @@ -205,30 +202,3 @@ getOutDir opts = case outDir opts of getOutFn :: Options -> String -> FilePath getOutFn opts mn = getOutDir opts mn <> ".json" - --- * JSON encoding - --- | Uses "Aeson.Pretty" to order the JSON fields. -encode :: ToJSON a => a -> BL.ByteString -encode = encodePretty' $ defConfig - { confIndent = Spaces 2 - , confCompare = keyOrder - [ "pretty" - , "tag" - , "name" - , "original", "simplified", "reduced", "normalised" - , "telescope", "patterns", "fields" - , "domain", "codomain" - , "abstraction", "body" - , "sort", "level", "literal" - , "head", "arguments" - , "variants", "reference", "variant" - , "index" - , "scopeGlobal", "scopeLocal" - , "type", "definition", "holes" - , "ctx", "goal", "term", "premises" - ] - } - -encodeFile :: ToJSON a => FilePath -> a -> IO () -encodeFile = \fn -> BL.writeFile fn . encode diff --git a/src/Output.hs b/src/Output.hs index c82c9c7..8a3b67e 100644 --- a/src/Output.hs +++ b/src/Output.hs @@ -3,14 +3,20 @@ -- internal Agda definition to this format. module Output where +import GHC.Generics ( Generic ) + import Control.Arrow ( second ) import Control.Applicative ( (<|>), liftA2 ) -import GHC.Generics ( Generic ) + import Data.List ( notElem, elemIndex ) import Data.String ( fromString ) -import Data.Aeson + +import qualified Data.ByteString.Lazy as BL +import Data.Aeson hiding ( encode ) import qualified Data.Aeson as JSON import qualified Data.Aeson.KeyMap as KM +import Data.Aeson.Encode.Pretty + ( encodePretty', Config(..), defConfig, Indent(..), keyOrder ) import Agda.Syntax.Common ( unArg ) import qualified Agda.Syntax.Common as A @@ -40,6 +46,7 @@ type DB = Int -- | A head of a λ-application can either be a defined name in the global scope, -- or a DeBruijn index into the local context. type Head = Either Name DB +pattern Ref x = Left x; pattern DB x = Right x -- * Generic constructions @@ -51,18 +58,17 @@ infixr 4 :>; pattern x :> y = Pretty {pretty = x, thing = y} data Pretty a = Pretty { pretty :: String , thing :: a - } deriving Generic -deriving instance Show a => Show (Pretty a) + } deriving (Generic, Show, Eq) instance ToJSON a => ToJSON (Pretty a) where toJSON (Pretty{..}) = let pretty' = toJSON pretty in case toJSON thing of (Object fs) -> object ("pretty" .= pretty' : KM.toList fs) t@(Array xs) -> object ["pretty" .= pretty', "telescope" .= t] - t -> object ["pretty" .= pretty', "thing" .= toJSON t] + t -> object ["pretty" .= pretty', "thing" .= t] instance FromJSON a => FromJSON (Pretty a) where parseJSON = withObject "Pretty" $ \v -> Pretty <$> v .: "pretty" - <*> (v .: "thing" <|> parseJSON (Object v)) + <*> (v .: "telescope" <|> v .: "thing" <|> parseJSON (Object v)) -- | Bundle a term with (several of) its normalised forms. -- @@ -76,8 +82,7 @@ data Reduced a = Reduced , simplified :: Maybe a , reduced :: Maybe a , normalised :: Maybe a - } deriving (Generic, Functor, Foldable, Traversable) -deriving instance Show a => Show (Reduced a) + } deriving (Generic, Show, Eq, Functor, Foldable, Traversable) instance ToJSON a => ToJSON (Reduced a) where toJSON r@(Reduced{..}) | Nothing <- simplified <|> reduced <|> normalised @@ -99,7 +104,7 @@ infixr 4 :~; pattern x :~ y = Named {name = x, item = y} data Named a = Named { name :: Name , item :: a - } deriving (Generic, Show) + } deriving (Generic, Show, Eq, Functor) instance ToJSON a => ToJSON (Named a) where toJSON (Named{..}) = let name' = toJSON name in case toJSON item of @@ -126,8 +131,8 @@ data TrainData = TrainData , scopePrivate :: Maybe [ScopeEntry] -- ^ The /private/ scope, containing private definitions not exported to the public, -- as well as system-generated definitions stemming from @where@ or @with@. - } deriving (Generic, Show) -instance ToJSON TrainData where toJSON = genericToJSON jsonOpts + } deriving (Generic, Show, Eq) +instance ToJSON TrainData where toJSON = genericToJSON jsonOpts instance FromJSON TrainData where parseJSON = genericParseJSON jsonOpts -- | Every 'ScopeEntry'' is /named/. @@ -136,11 +141,11 @@ type ScopeEntry = Named ScopeEntry' data ScopeEntry' = ScopeEntry { _type :: Pretty (Reduced Type) -- ^ The entry's type. - , definition :: Maybe (Pretty Definition) + , definition :: Pretty Definition -- ^ The actual body of this entry's definition. , holes :: Maybe [Sample] -- ^ Training data for each of the subterms in this entry's 'definition'. - } deriving (Generic, Show) + } deriving (Generic, Show, Eq) instance ToJSON ScopeEntry' where toJSON = genericToJSON jsonOpts instance FromJSON ScopeEntry' where parseJSON = genericParseJSON jsonOpts @@ -156,7 +161,7 @@ data Sample = Sample -- ^ The term that successfully fills the current 'goal'. , premises :: [Name] -- ^ Definitions used in this "proof", intended to be used for /premise selection/. - } deriving (Generic, Show, ToJSON, FromJSON) + } deriving (Generic, Show, Eq, ToJSON, FromJSON) -- | Agda definitions: datatypes, records, functions, postulates and primitives. data Definition @@ -180,7 +185,7 @@ data Definition -- ^ e.g. `postulate pred : ℕ → ℕ` | Primitive {} -- ^ e.g. `primitive primShowNat : ℕ → String` - deriving (Generic, Show, ToJSON, FromJSON) + deriving (Generic, Show, Eq, ToJSON, FromJSON) -- | Function clauses. data Clause = Clause @@ -190,8 +195,8 @@ data Clause = Clause -- ^ the actual patterns of this function clause , body :: Maybe Term -- ^ the right hand side of the clause (@Nothing@ for absurd clauses) - } deriving (Generic, Show) -instance ToJSON Clause where toJSON = genericToJSON jsonOpts + } deriving (Generic, Show, Eq) +instance ToJSON Clause where toJSON = genericToJSON jsonOpts instance FromJSON Clause where parseJSON = genericParseJSON jsonOpts -- | A telescope is a sequence of (named) types, a.k.a. bindings. @@ -211,18 +216,18 @@ data Term | Sort String -- ^ e.g. @Set@ | Level String -- ^ e.g. @0ℓ@ | UnsolvedMeta -- ^ i.e. @{!!}@ - deriving (Generic, Show) + deriving (Generic, Show, Eq) instance {-# OVERLAPPING #-} ToJSON Head where toJSON = object . \case - (Left n) -> [tag "ScopeReference", "name" .= toJSON n] - (Right i) -> [tag "DeBruijn", "index" .= toJSON i] + (Ref n) -> [tag "ScopeReference", "name" .= toJSON n] + (DB i) -> [tag "DeBruijn", "index" .= toJSON i] where tag s = "tag" .= JSON.String s instance {-# OVERLAPPING #-} FromJSON Head where parseJSON = withObject "Head" $ \o -> o .: "tag" >>= \case - String "ScopeReference" -> Left <$> o .: "name" - String "DeBruijn" -> Right <$> o .: "index" + String "ScopeReference" -> Ref <$> o .: "name" + String "DeBruijn" -> DB <$> o .: "index" tag -> fail $ "Cannot parse Head: unexpected \"tag\" field " <> show tag instance ToJSON Term where @@ -251,13 +256,13 @@ instance ToJSON Term where instance FromJSON Term where parseJSON = withObject "Term" $ \o -> o .: "tag" >>= \case - String "Pi" -> Pi undefined <$> liftA2 (:~) (o .: "name") (o .: "domain") - <*> o .: "codomain" + String "Pi" -> Pi True <$> liftA2 (:~) (o .: "name") (o .: "domain") + <*> o .: "codomain" -- T0D0: also serialise `isDep` String "Lambda" -> Lam <$> liftA2 (:~) (o .: "abstraction") (o .: "body") String "Application" -> App <$> o .: "head" <*> o .: "arguments" - String "ScopeReference" -> flip App [] . Left <$> o .: "name" - String "DeBruijn" -> flip App [] . Right <$> o .: "index" + String "ScopeReference" -> flip App [] . Ref <$> o .: "name" + String "DeBruijn" -> flip App [] . DB <$> o .: "index" String "Literal" -> Lit <$> o .: "literal" String "Sort" -> Sort <$> o .: "sort" String "Level" -> Level <$> o .: "level" @@ -314,12 +319,12 @@ instance A.Clause ~> Clause where instance A.DeBruijnPattern ~> Pattern where go = \case - A.VarP _ v -> return $ App (Right $ dbPatVarIndex v) [] + A.VarP _ v -> return $ App (DB $ dbPatVarIndex v) [] A.DotP _ t -> go t A.ConP c _ ps -> do - App (Left $ pp c) <$> traverse go (A.namedThing . unArg <$> ps) + App (Ref $ pp c) <$> traverse go (A.namedThing . unArg <$> ps) A.LitP _ lit -> return $ Lit (pp lit) - A.ProjP _ qn -> return $ App (Left $ pp qn) [] + A.ProjP _ qn -> return $ App (Ref $ pp qn) [] p@(A.IApplyP _ _ _ _) -> panic "pattern (cubical)" p p@(A.DefP _ _ _) -> panic "pattern (cubical)" p @@ -349,9 +354,9 @@ instance A.Term ~> Term where ab' <- go (unAbs ab) return $ Lam (pp (absName ab) :~ ab') -- ** applications - (A.Var i xs) -> App (Right i) <$> (traverse go xs) - (A.Def f xs) -> App (Left $ ppName f) <$> (traverse go xs) - (A.Con c _ xs) -> App (Left $ ppName $ conName c) <$> (traverse go xs) + (A.Var i xs) -> App (DB i) <$> (traverse go xs) + (A.Def f xs) -> App (Ref $ ppName f) <$> (traverse go xs) + (A.Con c _ xs) -> App (Ref $ ppName $ conName c) <$> (traverse go xs) -- ** other constants (A.Lit x) -> return $ Lit $ pp x (A.Level x) -> return $ Level $ pp x @@ -365,7 +370,7 @@ instance A.Term ~> Term where instance A.Elim ~> Term where go = \case (A.Apply x) -> go (unArg x) - (A.Proj _ qn) -> return $ App (Left $ ppName qn) [] + (A.Proj _ qn) -> return $ App (Ref $ ppName qn) [] (A.IApply _ _ x) -> go x -- * Utilities @@ -408,6 +413,8 @@ isNotCubical A.Clause{..} | otherwise = True +-- ** JSON encoding + -- | Configure JSON to omit empty (optional) fields and switch -- from camelCase to kebab-case. jsonOpts :: JSON.Options @@ -421,6 +428,33 @@ jsonOpts = defaultOptions s -> s } +-- | Uses "Aeson.Pretty" to order the JSON fields. +encode :: ToJSON a => a -> BL.ByteString +encode = encodePretty' $ defConfig + { confIndent = Spaces 2 + , confCompare = keyOrder + [ "pretty" + , "tag" + , "name" + , "original", "simplified", "reduced", "normalised" + , "telescope", "patterns", "fields" + , "domain", "codomain" + , "abstraction", "body" + , "sort", "level", "literal" + , "head", "arguments" + , "variants", "reference", "variant" + , "index" + , "scopeGlobal", "scopeLocal" + , "type", "definition", "holes" + , "ctx", "goal", "term", "premises" + ] + } + +encodeFile :: ToJSON a => FilePath -> a -> IO () +encodeFile = \fn -> BL.writeFile fn . encode + +-- + instance P.PrettyTCM A.Definition where prettyTCM d = go (theDef d) where