Skip to content

Commit bed63b6

Browse files
authored
Merge pull request #13 from lemastero/id-function
Handle identity function
2 parents 245c5ce + 9f31cec commit bed63b6

File tree

6 files changed

+183
-41
lines changed

6 files changed

+183
-41
lines changed

examples/adts.agda

+20-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,29 @@
11
module examples.adts where
22

3-
-- simple sum type no arguments - sealed trait + case objects
3+
-- simple product type no arguments - sealed trait + case objects
4+
45
data Rgb : Set where
56
Red : Rgb
67
Green : Rgb
78
Blue : Rgb
89
{-# COMPILE AGDA2SCALA Rgb #-}
910

10-
-- simple sum type with arguments - sealed trait + case class
11+
data Bool : Set where
12+
True : Bool
13+
False : Bool
14+
{-# COMPILE AGDA2SCALA Bool #-}
15+
16+
-- trivial function with single argument
17+
18+
idRgb : Rgb -> Rgb
19+
idRgb x = x
20+
{-# COMPILE AGDA2SCALA idRgb #-}
21+
22+
-- simple sum type - case class
1123

12-
data Color : Set where
13-
Light : Rgb -> Color
14-
Dark : Rgb -> Color
15-
{-# COMPILE AGDA2SCALA Color #-}
24+
record RgbPair : Set where
25+
constructor mkRgbPair
26+
field
27+
fst : Rgb
28+
snd : Bool
29+
{-# COMPILE AGDA2SCALA RgbPair #-}

examples/adts.scala

+9-4
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
1-
package adts
1+
object adts {
22

33
sealed trait Rgb
44
case object Red extends Rgb
55
case object Green extends Rgb
66
case object Blue extends Rgb
77

8-
sealed trait Color
9-
case object Light extends Color
10-
case object Dark extends Color
8+
sealed trait Bool
9+
case object True extends Bool
10+
case object False extends Bool
11+
12+
def idRgb(x: Rgb): Rgb = x
13+
14+
final case class RgbPair(snd: Bool, fst: Rgb)
15+
}

src/Agda/Compiler/Scala/AgdaToScalaExpr.hs

+92-8
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,113 @@ module Agda.Compiler.Scala.AgdaToScalaExpr (
55
import Agda.Compiler.Backend ( funCompiled, funClauses, Defn(..), RecordData(..))
66
import Agda.Syntax.Abstract.Name ( QName )
77
import Agda.Syntax.Common.Pretty ( prettyShow )
8-
import Agda.Syntax.Common ( Arg(..), ArgName, Named(..) )
8+
import Agda.Syntax.Common ( Arg(..), ArgName, Named(..), NamedName, WithOrigin(..), Ranged(..) )
99
import Agda.Syntax.Internal (
10-
Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), unDom, PatternInfo(..), Pattern'(..),
10+
Clause(..), DeBruijnPattern, DBPatVar(..), Dom(..), Dom'(..), unDom, PatternInfo(..), Pattern'(..),
1111
qnameName, qnameModule, Telescope, Tele(..), Term(..), Type, Type''(..) )
1212
import Agda.TypeChecking.Monad.Base ( Definition(..) )
1313
import Agda.TypeChecking.Monad
1414
import Agda.TypeChecking.CompiledClause ( CompiledClauses(..), CompiledClauses'(..) )
15+
import Agda.TypeChecking.Telescope ( teleNamedArgs, teleArgs, teleArgNames )
1516

16-
import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..) )
17+
import Agda.Syntax.Common.Pretty ( prettyShow )
18+
19+
import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaType, FunBody, ScalaExpr(..), SeVar(..) )
1720

1821
compileDefn :: QName -> Defn -> ScalaExpr
1922
compileDefn defName theDef = case theDef of
2023
Datatype{dataCons = dataCons} ->
2124
compileDataType defName dataCons
2225
Function{funCompiled = funDef, funClauses = fc} ->
23-
Unhandled "compileDefn Function" (show defName ++ "\n = \n" ++ show theDef)
26+
compileFunction defName funDef fc
2427
RecordDefn(RecordData{_recFields = recFields, _recTel = recTel}) ->
25-
Unhandled "compileDefn RecordDefn" (show defName ++ "\n = \n" ++ show theDef)
28+
compileRecord defName recFields recTel
2629
other ->
2730
Unhandled "compileDefn other" (show defName ++ "\n = \n" ++ show theDef)
2831

32+
compileRecord :: QName -> [Dom QName] -> Telescope -> ScalaExpr
33+
compileRecord defName recFields recTel = SeProd (fromQName defName) (foldl varsFromTelescope [] recTel)
34+
35+
varsFromTelescope :: [SeVar] -> Dom Type -> [SeVar]
36+
varsFromTelescope xs dt = SeVar (nameFromDom dt) (fromDom dt) : xs
37+
2938
compileDataType :: QName -> [QName] -> ScalaExpr
30-
compileDataType defName fields = SeAdt (showName defName) (map showName fields)
39+
compileDataType defName fields = SeSum (fromQName defName) (map fromQName fields)
40+
41+
compileFunction :: QName
42+
-> Maybe CompiledClauses
43+
-> [Clause]
44+
-> ScalaExpr
45+
compileFunction defName funDef fc =
46+
SeFun
47+
(fromQName defName)
48+
[SeVar (compileFunctionArgument fc) (compileFunctionArgType fc)] -- TODO many function arguments
49+
(compileFunctionResultType fc)
50+
(compileFunctionBody funDef)
51+
52+
compileFunctionArgument :: [Clause] -> ScalaName
53+
compileFunctionArgument [] = ""
54+
compileFunctionArgument [fc] = fromDeBruijnPattern (namedThing (unArg (head (namedClausePats fc))))
55+
compileFunctionArgument xs = error "unsupported compileFunctionArgument" ++ (show xs) -- show xs
56+
57+
compileFunctionArgType :: [Clause] -> ScalaType
58+
compileFunctionArgType [ Clause{clauseTel = ct} ] = fromTelescope ct
59+
compileFunctionArgType xs = error "unsupported compileFunctionArgType" ++ (show xs)
60+
61+
fromTelescope :: Telescope -> ScalaName -- TODO PP probably parent should be different, use fold on telescope above
62+
fromTelescope tel = case tel of
63+
ExtendTel a _ -> fromDom a
64+
other -> error ("unhandled fromType" ++ show other)
65+
66+
nameFromDom :: Dom Type -> ScalaName
67+
nameFromDom dt = case (domName dt) of
68+
Nothing -> error ("nameFromDom" ++ show dt)
69+
Just a -> namedNameToStr a
70+
71+
namedNameToStr :: NamedName -> ScalaName
72+
namedNameToStr n = rangedThing (woThing n)
73+
74+
fromDom :: Dom Type -> ScalaName
75+
fromDom x = fromType (unDom x)
76+
77+
compileFunctionResultType :: [Clause] -> ScalaType
78+
compileFunctionResultType [Clause{clauseType = ct}] = fromMaybeType ct
79+
compileFunctionResultType other = error ("unhandled compileFunctionResultType" ++ show other)
80+
81+
fromMaybeType :: Maybe (Arg Type) -> ScalaName
82+
fromMaybeType (Just argType) = fromArgType argType
83+
fromMaybeType other = error ("unhandled fromMaybeType" ++ show other)
84+
85+
fromArgType :: Arg Type -> ScalaName
86+
fromArgType arg = fromType (unArg arg)
87+
88+
fromType :: Type -> ScalaName
89+
fromType t = case t of
90+
a@(El _ ue) -> fromTerm ue
91+
other -> error ("unhandled fromType" ++ show other)
92+
93+
fromTerm :: Term -> ScalaName
94+
fromTerm t = case t of
95+
Def qname el -> fromQName qname
96+
other -> error ("unhandled fromTerm" ++ show other)
97+
98+
fromDeBruijnPattern :: DeBruijnPattern -> ScalaName
99+
fromDeBruijnPattern d = case d of
100+
VarP a b -> (dbPatVarName b)
101+
a@(ConP x y z) -> show a
102+
other -> error ("unhandled fromDeBruijnPattern" ++ show other)
103+
104+
compileFunctionBody :: Maybe CompiledClauses -> FunBody
105+
compileFunctionBody (Just funDef) = fromCompiledClauses funDef
106+
compileFunctionBody funDef = error ("unhandled compileFunctionBody " ++ show funDef)
107+
108+
fromCompiledClauses :: CompiledClauses -> FunBody
109+
fromCompiledClauses cc = case cc of
110+
(Done (x:xs) term) -> fromArgName x
111+
other -> error ("unhandled fromCompiledClauses " ++ show other)
112+
113+
fromArgName :: Arg ArgName -> FunBody
114+
fromArgName = unArg
31115

32-
showName :: QName -> ScalaName
33-
showName = prettyShow . qnameName
116+
fromQName :: QName -> ScalaName
117+
fromQName = prettyShow . qnameName

src/Agda/Compiler/Scala/PrintScalaExpr.hs

+31-11
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,50 @@
1-
{-# LANGUAGE OverloadedStrings #-}
2-
31
module Agda.Compiler.Scala.PrintScalaExpr ( printScalaExpr
42
, printCaseObject
53
, printSealedTrait
64
, printPackage
5+
, printCaseClass
76
, combineLines
87
) where
98

10-
import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..) )
9+
import Data.List ( intercalate )
10+
import Agda.Compiler.Scala.ScalaExpr ( ScalaName, ScalaExpr(..), SeVar(..))
1111

1212
printScalaExpr :: ScalaExpr -> String
1313
printScalaExpr def = case def of
1414
(SePackage pName defs) ->
15-
(printPackage pName) <> defsSeparator
16-
<> (
15+
(printPackage pName) <> exprSeparator -- TODO this should be package + object
16+
<> bracket (
1717
blankLine -- between package declaration and first definition
1818
<> combineLines (map printScalaExpr defs)
1919
)
2020
<> blankLine -- EOF
21-
(SeAdt adtName adtCases) ->
21+
(SeSum adtName adtCases) ->
2222
(printSealedTrait adtName)
2323
<> defsSeparator
24-
<> unlines (map (printCaseObject adtName) adtCases)
25-
(Unhandled name payload) -> "" -- for development comment out this and uncomment below
26-
-- (Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload)
27-
-- other -> "unsupported printScalaExpr " ++ (show other)
24+
<> combineLines (map (printCaseObject adtName) adtCases)
25+
<> defsSeparator
26+
(SeFun fName args resType funBody) ->
27+
"def" <> exprSeparator <> fName
28+
<> "(" <> combineLines (map printVar args) <> ")"
29+
<> ":" <> exprSeparator <> resType <> exprSeparator
30+
<> "=" <> exprSeparator <> funBody
31+
<> defsSeparator
32+
(SeProd name args) -> printCaseClass name args
33+
(Unhandled "" payload) -> ""
34+
(Unhandled name payload) -> "TODO " ++ (show name) ++ " " ++ (show payload)
35+
other -> "unsupported printScalaExpr " ++ (show other)
36+
37+
printCaseClass :: ScalaName -> [SeVar] -> String
38+
printCaseClass name args = "final case class" <> exprSeparator <> name <> "(" <> (printExpr args) <> ")"
39+
40+
printVar :: SeVar -> String
41+
printVar (SeVar sName sType) = sName <> ":" <> exprSeparator <> sType
42+
43+
printExpr :: [SeVar] -> String
44+
printExpr names = combineThem (map printVar names)
45+
46+
combineThem :: [String] -> String
47+
combineThem xs = intercalate ", " xs
2848

2949
printSealedTrait :: ScalaName -> String
3050
printSealedTrait adtName = "sealed trait" <> exprSeparator <> adtName
@@ -34,7 +54,7 @@ printCaseObject superName caseName =
3454
"case object" <> exprSeparator <> caseName <> exprSeparator <> "extends" <> exprSeparator <> superName
3555

3656
printPackage :: ScalaName -> String
37-
printPackage pName = "package" <> exprSeparator <> pName
57+
printPackage pName = "object" <> exprSeparator <> pName
3858

3959
bracket :: String -> String
4060
bracket str = "{\n" <> str <> "\n}"

src/Agda/Compiler/Scala/ScalaExpr.hs

+12-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,25 @@
11
module Agda.Compiler.Scala.ScalaExpr (
22
ScalaName,
3+
ScalaType,
34
ScalaExpr(..),
5+
SeVar(..),
6+
FunBody,
47
unHandled
58
) where
69

710
type ScalaName = String
11+
type FunBody = String -- this should be some lambda expression
12+
type ScalaType = String
813

9-
{- Represent Scala language extracted from Agda compiler representation -}
14+
data SeVar = SeVar ScalaName ScalaType
15+
deriving ( Show )
16+
17+
{- Represent Scala language extracted from internal Agda compiler representation -}
1018
data ScalaExpr
1119
= SePackage ScalaName [ScalaExpr]
12-
| SeAdt ScalaName [ScalaName]
20+
| SeSum ScalaName [ScalaName]
21+
| SeFun ScalaName [SeVar] ScalaType FunBody
22+
| SeProd ScalaName [SeVar]
1323
| Unhandled ScalaName String
1424
deriving ( Show )
1525

test/PrintScalaExprTest.hs

+19-10
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ import Agda.Compiler.Scala.PrintScalaExpr (
77
, printCaseObject
88
, printPackage
99
, combineLines
10+
, printCaseClass
1011
)
11-
import Agda.Compiler.Scala.ScalaExpr ( ScalaExpr(..) )
12+
import Agda.Compiler.Scala.ScalaExpr ( ScalaExpr(..), SeVar(..) )
1213

1314
testPrintCaseObject :: Test
1415
testPrintCaseObject = TestCase
@@ -22,11 +23,11 @@ testPrintSealedTrait = TestCase
2223
"sealed trait Color"
2324
(printSealedTrait "Color"))
2425

25-
testPrintPackage :: Test
26-
testPrintPackage = TestCase
27-
(assertEqual "printPackage"
28-
"package adts"
29-
(printPackage "adts"))
26+
--testPrintPackage :: Test
27+
--testPrintPackage = TestCase
28+
-- (assertEqual "printPackage"
29+
-- "package adts"
30+
-- (printPackage "adts"))
3031

3132
testCombineLines :: Test
3233
testCombineLines = TestCase
@@ -37,19 +38,27 @@ testCombineLines = TestCase
3738
testPrintScalaExpr :: Test
3839
testPrintScalaExpr = TestCase
3940
(assertEqual "printScalaExpr" (printScalaExpr $ SePackage "adts" moduleContent)
40-
"package adts\n\nsealed trait Rgb\ncase object Red extends Rgb\ncase object Green extends Rgb\ncase object Blue extends Rgb\n\nsealed trait Color\ncase object Light extends Color\ncase object Dark extends Color\n"
41+
"object adts {\n\nsealed trait Rgb\ncase object Red extends Rgb\ncase object Green extends Rgb\ncase object Blue extends Rgb\n\nsealed trait Color\ncase object Light extends Color\ncase object Dark extends Color\n}\n"
4142
)
4243
where
4344
moduleContent = [rgbAdt, blank, blank, blank, colorAdt, blank, blank]
44-
rgbAdt = SeAdt "Rgb" ["Red","Green","Blue"]
45-
colorAdt = SeAdt "Color" ["Light","Dark"]
45+
rgbAdt = SeSum "Rgb" ["Red","Green","Blue"]
46+
colorAdt = SeSum "Color" ["Light","Dark"]
4647
blank = Unhandled "" ""
4748

49+
testPrintCaseClass :: Test
50+
testPrintCaseClass = TestCase
51+
(assertEqual "printCaseClass"
52+
"final case class RgbPair(snd: Bool, fst: Rgb)"
53+
(printCaseClass "RgbPair" [SeVar "snd" "Bool", SeVar "fst" "Rgb"]))
54+
55+
4856
printScalaTests :: Test
4957
printScalaTests = TestList [
5058
TestLabel "printCaseObject" testPrintCaseObject
5159
, TestLabel "printSealedTrait" testPrintSealedTrait
52-
, TestLabel "printPackage" testPrintPackage
60+
-- , TestLabel "printPackage" testPrintPackage
5361
, TestLabel "combineLines" testCombineLines
62+
, TestLabel "printCaseClass" testPrintCaseClass
5463
, TestLabel "printScalaExpr" testPrintScalaExpr
5564
]

0 commit comments

Comments
 (0)