diff --git a/src/Nix/TH.hs b/src/Nix/TH.hs index f8368ec83..526465219 100644 --- a/src/Nix/TH.hs +++ b/src/Nix/TH.hs @@ -17,26 +17,39 @@ import Nix.Parser quoteExprExp :: String -> ExpQ quoteExprExp s = do - expr <- - either - (fail . show) - pure - (parseNixText $ toText s) + expr <- parseExpr s dataToExpQ - (const Nothing `extQ` metaExp (freeVars expr) `extQ` (pure . (TH.lift :: Text -> Q Exp))) + (extQOnFreeVars metaExp expr `extQ` (pure . (TH.lift :: Text -> Q Exp))) expr quoteExprPat :: String -> PatQ quoteExprPat s = do - expr <- - either - (fail . show) - pure - (parseNixText $ toText s) + expr <- parseExpr s dataToPatQ - (const Nothing `extQ` metaPat (freeVars expr)) + (extQOnFreeVars metaPat expr) expr +-- | Helper function. +extQOnFreeVars + :: ( Typeable b + , Typeable loc + ) + => ( Set VarName + -> loc + -> Maybe q + ) + -> NExpr + -> b + -> Maybe q +extQOnFreeVars f e = extQ (const Nothing) (f $ freeVars e) + +parseExpr :: (MonadFail m, ToText a) => a -> m NExpr +parseExpr s = + either + (fail . show) + pure + (parseNixText $ toText s) + freeVars :: NExpr -> Set VarName freeVars e = case unFix e of (NConstant _ ) -> mempty @@ -44,11 +57,11 @@ freeVars e = case unFix e of (NSym var ) -> one var (NList list ) -> mapFreeVars list (NSet NonRecursive bindings) -> bindFreeVars bindings - (NSet Recursive bindings) -> Set.difference (bindFreeVars bindings) (bindDefs bindings) + (NSet Recursive bindings) -> diffBetween bindFreeVars bindDefs bindings (NLiteralPath _ ) -> mempty (NEnvPath _ ) -> mempty (NUnary _ expr ) -> freeVars expr - (NBinary _ left right ) -> ((<>) `on` freeVars) left right + (NBinary _ left right ) -> collectFreeVars left right (NSelect expr path orExpr) -> Set.unions [ freeVars expr @@ -69,18 +82,22 @@ freeVars e = case unFix e of ) (NLet bindings expr ) -> freeVars expr <> - Set.difference - (bindFreeVars bindings) - (bindDefs bindings) + diffBetween bindFreeVars bindDefs bindings (NIf cond th el ) -> Set.unions $ freeVars <$> [cond, th, el] -- Evaluation is needed to find out whether x is a "real" free variable in `with y; x`, we just include it -- This also makes sense because its value can be overridden by `x: with y; x` - (NWith set expr ) -> ((<>) `on` freeVars) set expr - (NAssert assertion expr ) -> ((<>) `on` freeVars) assertion expr + (NWith set expr ) -> collectFreeVars set expr + (NAssert assertion expr ) -> collectFreeVars assertion expr (NSynHole _ ) -> mempty where + diffBetween :: (a -> Set VarName) -> (a -> Set VarName) -> a -> Set VarName + diffBetween g f b = Set.difference (g b) (f b) + + collectFreeVars :: NExpr -> NExpr -> Set VarName + collectFreeVars = (<>) `on` freeVars + bindDefs :: Foldable t => t (Binding NExpr) -> Set VarName bindDefs = foldMap bind1Def where diff --git a/tests/NixLanguageTests.hs b/tests/NixLanguageTests.hs index 3c974318a..847ce82d3 100644 --- a/tests/NixLanguageTests.hs +++ b/tests/NixLanguageTests.hs @@ -77,14 +77,14 @@ genTests = do <$> globDir1 (compile "*-*-*.*") "data/nix/tests/lang" let testsByName = groupBy (takeFileName . dropExtensions) testFiles let testsByType = groupBy testType (Map.toList testsByName) - let testGroups = fmap mkTestGroup (Map.toList testsByType) + let testGroups = mkTestGroup <$> Map.toList testsByType pure $ localOption (mkTimeout 2000000) $ testGroup "Nix (upstream) language tests" testGroups where testType (fullpath, _files) = take 2 $ splitOn "-" $ takeFileName fullpath mkTestGroup (kind, tests) = - testGroup (String.unwords kind) $ fmap (mkTestCase kind) tests + testGroup (String.unwords kind) $ mkTestCase kind <$> tests mkTestCase kind (basename, files) = testCase (takeFileName basename) $ do time <- liftIO getCurrentTime let opts = defaultOptions time