Skip to content

Commit

Permalink
experimental/graph: add exp, cond, and more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jxy committed Jul 18, 2024
1 parent 737ffd0 commit e6cf14c
Show file tree
Hide file tree
Showing 6 changed files with 418 additions and 85 deletions.
4 changes: 2 additions & 2 deletions src/experimental/graph/core.nim
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
TODO
- if
- function/lambda
]#
Expand Down Expand Up @@ -73,7 +72,7 @@ proc nodeRepr*(x: Gvalue): string =
if f != nil:
result &= " " & $f & "@0X" & strip(toHex(cast[int](f)), trailing = false, chars = {'0'})

method newOneOf*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("newOneOf(" & $x & ")")
method newOneOf*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("newOneOf(" & $x & ")") ## Be sure to zero init fields
method valCopy*(z: Gvalue, x: Gvalue) {.base.} = raiseErrorBaseMethod("valCopy(" & $z & "," & $x & ")")

proc assignGvalue(z: Gvalue, x: Gvalue) =
Expand Down Expand Up @@ -145,6 +144,7 @@ method `+`*(x: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("`+`("
method `*`*(x: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("`*`(" & $x & ", " & $y & ")")
method `-`*(x: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("`-`(" & $x & ", " & $y & ")")
method `/`*(x: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("`/`(" & $x & ", " & $y & ")")
method exp*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("exp(" & $x & ")")

proc updated*(x: Gvalue) =
var epoch {.global.} = 0
Expand Down
26 changes: 15 additions & 11 deletions src/experimental/graph/gauge.nim
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ proc toGvalue*(x: Gauge): Ggauge =
result = Ggauge(gval: x)
result.updated

method newOneOf*(x: Ggauge): Gvalue = Ggauge(gval: x.gval.newOneOf)
method newOneOf*(x: Ggauge): Gvalue =
let g = x.gval.newOneOf
threads:
for f in g:
f := 0.0
Ggauge(gval: g)
method valCopy*(z: Ggauge, x: Ggauge) =
let u = z.gval
let v = x.gval
Expand All @@ -38,7 +43,6 @@ method retr*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("retr(" & $x & "
method adj*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("adj(" & $x & ")")
method norm2*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("norm2(" & $x & ")")
method redot*(x: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("redot(" & $x & "," & $y & ")")
method exp*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("exp(" & $x & ")")
method expDeriv*(b: Gvalue, x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("expDeriv(" & $b & "," & $x & ")")
method projTAH*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("projTAH(" & $x & ")")

Expand Down Expand Up @@ -133,7 +137,7 @@ proc neggb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =

let negg = newGfunc(forward = neggf, backward = neggb, name = "-g")

method `-`(x: Ggauge): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x)], gfunc: negg)
method `-`*(x: Ggauge): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x)], gfunc: negg)

proc addsgb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
if zb == nil:
Expand All @@ -156,7 +160,7 @@ proc addsgf(v: Gvalue) =

let addsg = newGfunc(forward = addsgf, backward = addsgb, name = "s+g")

method `+`(x: Gscalar, y: Ggauge): Gvalue = Ggauge(gval: y.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: addsg)
method `+`*(x: Gscalar, y: Ggauge): Gvalue = Ggauge(gval: y.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: addsg)

proc addggb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
if zb == nil:
Expand All @@ -174,7 +178,7 @@ proc addggf(v: Gvalue) =

let addgg = newGfunc(forward = addggf, backward = addggb, name = "g+g")

method `+`(x: Ggauge, y: Ggauge): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: addgg)
method `+`*(x: Ggauge, y: Ggauge): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: addgg)

proc mulsgb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
if zb == nil:
Expand All @@ -197,7 +201,7 @@ proc mulsgf(v: Gvalue) =

let mulsg = newGfunc(forward = mulsgf, backward = mulsgb, name = "s*g")

method `*`(x: Gscalar, y: Ggauge): Gvalue = Ggauge(gval: y.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: mulsg)
method `*`*(x: Gscalar, y: Ggauge): Gvalue = Ggauge(gval: y.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: mulsg)

proc mulggb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
if zb == nil:
Expand All @@ -220,7 +224,7 @@ proc mulggf(v: Gvalue) =

let mulgg = newGfunc(forward = mulggf, backward = mulggb, name = "g*g")

method `*`(x: Ggauge, y: Ggauge): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: mulgg)
method `*`*(x: Ggauge, y: Ggauge): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: mulgg)

proc redotggb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
case i
Expand Down Expand Up @@ -249,7 +253,7 @@ proc redotggf(v: Gvalue) =

let redotgg = newGfunc(forward = redotggf, backward = redotggb, name = "redotgg")

method redot(x: Ggauge, y: Ggauge): Gvalue = Gscalar(inputs: @[Gvalue(x), y], gfunc: redotgg)
method redot*(x: Ggauge, y: Ggauge): Gvalue = Gscalar(inputs: @[Gvalue(x), y], gfunc: redotgg)

proc subgsb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
if zb == nil:
Expand All @@ -272,7 +276,7 @@ proc subgsf(v: Gvalue) =

let subgs = newGfunc(forward = subgsf, backward = subgsb, name = "g-s")

method `-`(x: Ggauge, y: Gscalar): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: subgs)
method `-`*(x: Ggauge, y: Gscalar): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: subgs)

proc subggb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
if zb == nil:
Expand All @@ -295,7 +299,7 @@ proc subggf(v: Gvalue) =

let subgg = newGfunc(forward = subggf, backward = subggb, name = "g-g")

method `-`(x: Ggauge, y: Ggauge): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: subgg)
method `-`*(x: Ggauge, y: Ggauge): Gvalue = Ggauge(gval: x.gval.newOneOf, inputs: @[Gvalue(x), y], gfunc: subgg)

proc expgb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
if zb == nil:
Expand Down Expand Up @@ -531,7 +535,7 @@ proc redotccf(v: Gvalue) =

let redotcc = newGfunc(forward = redotccf, backward = redotccb, name = "redotcc")

method redot(x: Gactcoeff, y: Gactcoeff): Gvalue = Gscalar(inputs: @[Gvalue(x), y], gfunc: redotcc)
method redot*(x: Gactcoeff, y: Gactcoeff): Gvalue = Gscalar(inputs: @[Gvalue(x), y], gfunc: redotcc)

const
C1Symanzik = -1.0/12.0 # tree-level
Expand Down
27 changes: 3 additions & 24 deletions src/experimental/graph/multi.nim
Original file line number Diff line number Diff line change
@@ -1,30 +1,9 @@
import core, scalar

type
Gint* {.final.} = ref object of Gvalue
ival: int
Gmulti* {.final.} = ref object of Gvalue
mval: seq[Gvalue]

proc getint*(x: Gvalue): int = Gint(x).ival

proc `getint=`*(x: Gvalue, y: int) =
let xs = Gint(x)
xs.ival = y

proc update*(x: Gvalue, y: int) =
x.getint = y
x.updated

converter toGvalue*(x: int): Gvalue =
result = Gint(ival: x)
result.updated

method newOneOf*(x: Gint): Gvalue = Gint()
method valCopy*(z: Gint, x: Gint) = z.ival = x.ival

method `$`*(x: Gint): string = $x.ival

proc getmulti*(x: Gvalue): seq[Gvalue] = Gmulti(x).mval

proc `getmulti=`*(x: Gvalue, y: seq[Gvalue]) =
Expand Down Expand Up @@ -58,7 +37,7 @@ method updateAt*(x: Gvalue, i: Gvalue, y: Gvalue): Gvalue {.base.} = raiseErrorB
proc getAtmb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
case i
of 0:
return z.inputs[0].updateAt(i, zb)
return z.inputs[0].newOneOf.updateAt(i, zb)
else:
raiseValueError("i must be 0, got: " & $i)

Expand All @@ -70,14 +49,14 @@ proc getAtmf(v: Gvalue) =
let getAtm = newGfunc(forward = getAtmf, backward = getAtmb, name = "getAtm")

method `[]`*(x: Gmulti, i: Gint): Gvalue =
result = newOneOf x.mval[i.ival]
result = newOneOf x.mval[i.getint]
result.inputs = @[Gvalue(x), i]
result.gfunc = getAtm

proc updateAtmb(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue =
case i
of 0:
return zb.updateAt(i, 0.0) # TODO: need a zero method
return zb.updateAt(i, zb.inputs[0].newOneOf)
of 2:
return zb[i]
else:
Expand Down
Loading

0 comments on commit e6cf14c

Please sign in to comment.