Skip to content

Commit

Permalink
Merge branch 'refactors'
Browse files Browse the repository at this point in the history
  • Loading branch information
khibino committed Dec 6, 2024
2 parents 83bcd5a + c262921 commit da00016
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 39 deletions.
32 changes: 30 additions & 2 deletions dnsext-iterative/DNS/Iterative/Query/API.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import DNS.Iterative.Query.Helpers
import DNS.Iterative.Query.Local (takeLocalResult)
import DNS.Iterative.Query.Resolve
import DNS.Iterative.Query.Types
import DNS.Iterative.Query.Utils (logQueryErrors)
import DNS.Iterative.Query.Utils (logLn, pprMessage)

-----

Expand Down Expand Up @@ -93,6 +93,35 @@ getResponse' name qaction liftR denied replied env reqM q@(Question bn typ cls)
reqEH = DNS.ednsHeader reqM
{- FOURMOLU_ENABLE -}

{- FOURMOLU_DISABLE -}
logQueryErrors :: String -> DNSQuery a -> DNSQuery a
logQueryErrors prefix q = do
handleQueryError left return q
where
left qe = do
logQueryError qe
throwError qe
logQueryError qe = case qe of
DnsError de ss -> logDnsError de ss
NotResponse addrs resp msg -> logNotResponse addrs resp msg
InvalidEDNS addrs eh msg -> logInvalidEDNS addrs eh msg
HasError addrs rcode msg -> logHasError addrs rcode msg
logDnsError de ss = case de of
NetworkFailure {} -> putLog detail
DecodeError {} -> putLog detail
RetryLimitExceeded -> putLog detail
UnknownDNSError {} -> putLog detail
_ -> pure ()
where detail = show de ++ ": " ++ intercalate ", " ss
logNotResponse addrs False msg = putLog $ pprAddrs addrs ++ ":\n" ++ pprMessage "not response:" msg
logNotResponse _addrs True _msg = pure ()
logInvalidEDNS addrs DNS.InvalidEDNS msg = putLog $ pprAddrs addrs ++ ":\n" ++ pprMessage "invalid EDNS:" msg
logInvalidEDNS _ _ _msg = pure ()
logHasError _addrs _rcode _msg = pure ()
pprAddrs = unwords . map show
putLog = logLn Log.WARN . (prefix ++)
{- FOURMOLU_ENABLE -}

ctrlFromRequestHeader :: DNSFlags -> EDNSheader -> QueryControls
ctrlFromRequestHeader reqF reqEH = DNS.doFlag doOp <> DNS.cdFlag cdOp <> DNS.adFlag adOp
where
Expand Down Expand Up @@ -151,7 +180,6 @@ queryErrorReply ident rqs left right qe = case qe of
NotResponse{} -> right $ message DNS.ServFail
InvalidEDNS{} -> right $ message DNS.ServFail
HasError _as rc _m -> right $ message rc
QueryDenied -> left "QueryDenied"
where
dnsError e = foldDNSErrorToRCODE (left $ "DNSError: " ++ show e) (right . message) e
message rc = replyDNSMessage ident rqs rc resFlags [] []
Expand Down
7 changes: 3 additions & 4 deletions dnsext-iterative/DNS/Iterative/Query/Types.hs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ module DNS.Iterative.Query.Types (
DFreshState (..),
runDNSQuery,
throwDnsError,
handleDnsError,
handleQueryError,
handleResponseError,
) where

Expand Down Expand Up @@ -164,7 +164,6 @@ data QueryError
| NotResponse [Address] Bool DNSMessage
| InvalidEDNS [Address] DNS.EDNSheader DNSMessage
| HasError [Address] DNS.RCODE DNSMessage
| QueryDenied
deriving (Show)

type ContextT m = ReaderT Env (ReaderT QueryContext m)
Expand All @@ -187,12 +186,12 @@ runDNSQuery q = runReaderT . runReaderT (runExceptT q)
throwDnsError :: DNSError -> DNSQuery a
throwDnsError = throwError . (`DnsError` [])

handleDnsError
handleQueryError
:: (QueryError -> DNSQuery a)
-> (a -> DNSQuery a)
-> DNSQuery a
-> DNSQuery a
handleDnsError left right q = either left right =<< lift (runExceptT q)
handleQueryError left right q = either left right =<< lift (runExceptT q)

-- example instances
-- - responseErrEither = handleResponseError Left Right :: DNSMessage -> Either QueryError DNSMessage
Expand Down
34 changes: 1 addition & 33 deletions dnsext-iterative/DNS/Iterative/Query/Utils.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ import Data.List.NonEmpty (toList)

-- dnsext packages
import qualified DNS.Log as Log
import DNS.Types (DNSError (..), DNSMessage (..))
import qualified DNS.Types as DNS
import DNS.Types (DNSMessage (..))
import Data.IP (IP (IPv4, IPv6))
import System.Console.ANSI.Types

Expand Down Expand Up @@ -46,37 +45,6 @@ pindents prefix (x:xs) = (prefix ++ ": " ++ x) : map indent xs
pprAddr :: Address -> String
pprAddr (ip, port) = show ip ++ "#" ++ show port

{- FOURMOLU_DISABLE -}
logQueryErrors :: String -> DNSQuery a -> DNSQuery a
logQueryErrors prefix q = do
handleDnsError left return q
where
left qe = do
logQueryError qe
throwError qe
logQueryError qe = case qe of
DnsError de ss -> logDnsError de ss
NotResponse addrs resp msg -> logNotResponse addrs resp msg
InvalidEDNS addrs eh msg -> logInvalidEDNS addrs eh msg
HasError addrs rcode msg -> logHasError addrs rcode msg
QueryDenied -> logQueryDenied
logDnsError de ss = case de of
NetworkFailure {} -> putLog detail
DecodeError {} -> putLog detail
RetryLimitExceeded -> putLog detail
UnknownDNSError {} -> putLog detail
_ -> pure ()
where detail = show de ++ ": " ++ intercalate ", " ss
logNotResponse addrs False msg = putLog $ pprAddrs addrs ++ ":\n" ++ pprMessage "not response:" msg
logNotResponse _addrs True _msg = pure ()
logInvalidEDNS addrs DNS.InvalidEDNS msg = putLog $ pprAddrs addrs ++ ":\n" ++ pprMessage "invalid EDNS:" msg
logInvalidEDNS _ _ _msg = pure ()
logHasError _addrs _rcode _msg = pure ()
logQueryDenied = pure ()
pprAddrs = unwords . map show
putLog = logLn Log.WARN . (prefix ++)
{- FOURMOLU_ENABLE -}

printResult :: Either QueryError DNSMessage -> IO ()
printResult = either print (putStr . pprMessage "result")

Expand Down

0 comments on commit da00016

Please sign in to comment.