Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions library/TensorFlow/DepTyped/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE TypeInType #-}
{-# LANGUAGE NoStarIsType #-}

module TensorFlow.DepTyped.Base (
KnownNatList(natListVal),
Expand All @@ -37,7 +38,8 @@ module TensorFlow.DepTyped.Base (
import GHC.TypeLits (Nat, KnownNat, natVal, type (*), Symbol, TypeError, ErrorMessage(Text, ShowType,
(:<>:)), type (-), type (+))
import Data.Proxy (Proxy(Proxy))
import Data.Promotion.Prelude (type If, type (:<), type (:>), type (:||), type (:==), type Reverse, type Length)
import Data.Singletons.Prelude (type If, type (<), type (>), type (||), type (==), type Reverse)
import Data.Singletons.Prelude.Foldable (type Length)
import Data.Kind (Constraint, Type)

class KnownNatList (ns :: [Nat]) where
Expand All @@ -64,11 +66,11 @@ type family AddPlaceholder (name :: Symbol) (shape :: [Nat]) (t :: Type) (placeh
AddPlaceholder n s t '[] = '[ '(n, s, t) ]
AddPlaceholder n s t ('(n, s, t) ': phs) = '(n, s, t) ': phs
AddPlaceholder n1 s1 t1 ('(n2, s2, t2) ': phs) =
If (n1 :< n2)
If (n1 < n2)
('(n1, s1, t1) ': '(n2, s2, t2) ': phs)
(If (n1 :> n2)
(If (n1 > n2)
('(n2, s2, t2) ': AddPlaceholder n1 s1 t1 phs)
(If (t1 :== t2)
(If (t1 == t2)
(TypeError ('Text "The placeholder " ':<>: 'ShowType n1 ':<>:
'Text " appears to have defined two different shapes " ':<>:
'ShowType s1 ':<>: 'Text " and " ':<>: 'ShowType s2))
Expand Down Expand Up @@ -115,9 +117,9 @@ type family BroadcastShapes' (revshape1::[Nat]) (revshape2::[Nat]) (shape1::[Nat
BroadcastShapes' '[] shape2 _ _ = shape2
BroadcastShapes' shape1 '[] _ _ = shape1
BroadcastShapes' (n:shape1) (m:shape2) origshape1 origshape2 =
If (n:==1 :|| n:==m)
If (n == 1 || n == m)
(m : BroadcastShapes' shape1 shape2 origshape1 origshape2)
(If (m:==1)
(If (m == 1)
(n : BroadcastShapes' shape1 shape2 origshape1 origshape2)
(TypeError ('Text "Error: shapes " ':<>: 'ShowType origshape1
':<>: 'Text " and " ':<>: 'ShowType origshape2
Expand Down
27 changes: 11 additions & 16 deletions stack.yaml
Original file line number Diff line number Diff line change
@@ -1,30 +1,25 @@
resolver: lts-11.18
resolver: lts-13.13

packages:
- '.'
- 'tensorflow-mnist-deptyped'
- location:
git: https://github.com/tensorflow/haskell.git
commit: c731a6f768c2f0285632b0b2097b2682a0c45861
extra-deps:
- git: https://github.com/tensorflow/haskell.git
commit: 26eebce98f1e5fc924c5141fbe3f400f8f0cfde8
subdirs:
- tensorflow
- tensorflow-core-ops
- tensorflow-logging
- tensorflow-mnist
- tensorflow-mnist-input-data
- tensorflow-opgen
- tensorflow-ops
- tensorflow-proto
- tensorflow-records
- tensorflow-records-conduit
- tensorflow-test
extra-dep: true
extra-deps:
- haskell-src-exts-1.19.1
- proto-lens-protobuf-types-0.2.2.0
- proto-lens-0.2.2.0
- proto-lens-protoc-0.2.2.3
- proto-lens-descriptors-0.2.2.0
- haskell-src-exts-1.21.0
- proto-lens-0.4.0.1
- proto-lens-protobuf-types-0.4.0.1
- proto-lens-runtime-0.4.0.2
- proto-lens-setup-0.4.0.2
- lens-labels-0.3.0.1
- proto-lens-protoc-0.4.0.2
- snappy-0.2.0.2
- snappy-framing-0.1.1

Expand Down
10 changes: 5 additions & 5 deletions tensorflow-deptyped.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ library
ghc-options: -Wall -Wincomplete-uni-patterns -Wincomplete-record-updates -Wmissing-import-lists
build-depends:
base >= 4.9 && < 5
, tensorflow
, tensorflow-ops
, tensorflow-core-ops
, singletons
, tensorflow >= 0.2.0.0 && < 0.3
, tensorflow-ops >= 0.2.0.0 && < 0.3
, tensorflow-core-ops >= 0.2.0.0 && < 0.3
, singletons >= 2.5 && < 2.6
, vector
, vector-sized
, bytestring
Expand Down Expand Up @@ -59,7 +59,7 @@ executable tensorflow-haskell-deptyped
, bytestring
, vector
, vector-sized
, singletons
, singletons >= 2.5 && < 2.6
, tensorflow
, tensorflow-ops
default-language: Haskell2010
Expand Down
1 change: 1 addition & 0 deletions tensorflow-mnist-deptyped/app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE NoStarIsType #-}

import Control.Monad (forM_, forM, when)
import Control.Monad.IO.Class (liftIO)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE NoStarIsType #-}

module TensorFlow.Examples.MNISTDeptyped.Parse (
MNIST,
Expand Down
12 changes: 6 additions & 6 deletions tensorflow-mnist-deptyped/tensorflow-mnist-deptyped.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,17 @@ library
exposed-modules: TensorFlow.Examples.MNISTDeptyped.Parse
, TensorFlow.Examples.MNIST.TrainedGraph
other-modules: Paths_tensorflow_mnist_deptyped
build-depends: proto-lens == 0.2.*
build-depends: proto-lens == 0.4.*
, base >= 4.7 && < 5
, binary
, bytestring
, filepath
, lens-family
, containers
, split
, tensorflow-proto == 0.1.*
, tensorflow-core-ops == 0.1.*
, tensorflow
, tensorflow-proto >= 0.2.0.0 && < 0.3
, tensorflow-core-ops >= 0.2.0.0 && < 0.3
, tensorflow >= 0.2.0.0 && < 0.3
, text
, vector
, vector-sized
Expand All @@ -49,8 +49,8 @@ executable mnist-deptyped
, tensorflow
, tensorflow-mnist-deptyped
, tensorflow-mnist-input-data
, tensorflow-ops
, tensorflow-proto
, tensorflow-ops >= 0.2.0.0 && < 0.3
, tensorflow-proto >= 0.2.0.0 && < 0.3
, text
, transformers
, vector
Expand Down