Skip to content

Commit

Permalink
Fixed bug in type inference of vectorization
Browse files Browse the repository at this point in the history
  • Loading branch information
ilkka-torma committed Mar 1, 2018
1 parent 1c84258 commit 4f4b01e
Showing 1 changed file with 23 additions and 19 deletions.
42 changes: 23 additions & 19 deletions Infer.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import Data.Set ((\\))
import qualified Data.Map as Map
import Data.List (nub, unzip3)
import Control.Monad.State
import Control.Monad (when, guard)
import Control.Monad (when, guard, forM_)

-- Possible results for enforcing a typeclass
data Enforce = Enforce {otherCons :: [TClass], -- "simpler" typeclass constraints
otherUnis :: [(Type, Type)]} -- types to be unified
deriving (Show)

-- Find a nesting depth at which list-nested t1 equals t2
eqDepth :: Type -> Type -> Maybe Int
Expand Down Expand Up @@ -65,7 +66,10 @@ holds c@(Vect2 t1 t2 t3 s1 s2 s3)
| TList _ <- t2 = Nothing
| TFun _ _ <- t1 = Nothing
| TFun _ _ <- t2 = Nothing
| TFun _ _ <- t3 = Nothing -- Lists and functions are not bi-vectorizable for now
| TFun _ _ <- t3 = Nothing
| TFun _ _ <- s1 = Nothing
| TFun _ _ <- s2 = Nothing
| TFun _ _ <- s3 = Nothing -- Lists and functions are not bi-vectorizable for now
| s1 == t1, s2 == t2, s3 == t3 = Just $ Enforce [] []
| Nothing <- uniDepth t1 s1 = Nothing
| Nothing <- uniDepth t2 s2 = Nothing
Expand All @@ -81,22 +85,22 @@ holds c@(Vect2 t1 t2 t3 s1 s2 s3)
| otherwise = Just $ Enforce [c] []

-- Default typeclass instances, given as unifiable pairs of types
defInst :: TClass -> [(Type, Type)]
defInst (Concrete t) = [(t, TConc TNum)]
defInst (Vect t1 t2 s1 s2) = [(s1, iterate TList t1 !! max 0 (n2 - n1)),
(s2, iterate TList t2 !! max 0 (n1 - n2))]
-- The choice is nondeterministic, which is modeled by a list of possibilities
defInst :: TClass -> [[(Type, Type)]]
defInst (Concrete t) = [[(t, TConc TNum)]]
defInst (Vect t1 t2 s1 s2) = [[(s1, iterate TList t1 !! max n1 n2)
,(s2, iterate TList t2 !! max n1 n2)]]
where Just n1 = uniDepth t1 s1
Just n2 = uniDepth t2 s2
defInst (Vect2 t1 t2 t3 s1 s2 s3)
| n1 >= n2 = [(s1, iterate TList t1 !! max 0 (n3 - n1)),
(s2, iterate TList t2 !! n2),
(s3, iterate TList t3 !! max 0 (n1 - n3))]
| otherwise = [(s1, iterate TList t1 !! n1),
(s2, iterate TList t2 !! max 0 (n3 - n2)),
(s3, iterate TList t3 !! max 0 (n2 - n3))]
defInst (Vect2 t1 t2 t3 s1 s2 s3) = [ [(s1, iterate TList t1 !! k1)
,(s2, iterate TList t2 !! k2)
,(s3, iterate TList t3 !! max k1 k2)]
| k1 <- [maxN, maxN-1 .. n1]
, k2 <- [maxN, maxN-1 .. n2]]
where Just n1 = uniDepth t1 s1
Just n2 = uniDepth t2 s2
Just n3 = uniDepth t3 s3
maxN = maximum [n1, n2, n3]

-- Type substitution: map from type vars to types
type Sub = Map.Map TLabel Type
Expand Down Expand Up @@ -267,7 +271,7 @@ unify t1 t2 = do
checkCons :: [TClass] -> Infer [TClass]
checkCons (x:_) | trace' 2 ("checking " ++ show x) False = undefined
checkCons [] = return []
checkCons (c:cs) = case {-traceShow' (c, holds c)-} holds c of
checkCons (c:cs) = case traceShow' 2 (c, holds c) $ holds c of
Just (Enforce newCs unis) -> do
mapM (uncurry unify) unis
(newCs ++) <$> checkCons cs
Expand Down Expand Up @@ -446,11 +450,11 @@ inferType constrainRes typeConstr exprs = trace' 1 ("inferring program " ++ show
when constrainRes $ do
CType conCons genType <- instantiate typeConstr
trace' 1 "applying constraints" $ unify genType typ
trace' 1 "defaulting instances" $ flip mapM_ (nub $ infCons ++ conCons) $ \con -> do
newCon <- checkCons =<< substitute [con]
case newCon of
[] -> return ()
[con'] -> mapM_ (uncurry unify) $ defInst con'
trace' 1 "defaulting instances" $ forM_ (nub $ infCons ++ conCons) $ \con -> do
newCons <- checkCons =<< substitute [con]
forM_ newCons $ \newCon -> do
insts <- lift $ defInst newCon
mapM_ (uncurry unify) insts
lExprs <- Map.assocs <$> gets lineExprs
flip mapM [(i, exp, typ) | (i, Processed exp typ) <- lExprs] $
\(i, exp, typ) -> do
Expand Down

0 comments on commit 4f4b01e

Please sign in to comment.