Skip to content

Commit

Permalink
add force solver statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
James Osborn committed Jul 9, 2024
1 parent 54f251b commit e9fa6a3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 12 deletions.
33 changes: 25 additions & 8 deletions src/experimental/stagag.nim
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,19 @@ var spa = initSolverParams()
spa.r2req = arsq
spa.maxits = 10000
#spa.backend = sbQex
var spf = initSolverParams()
#spf.subsetName = "even"
spf.r2req = frsq
spf.maxits = 10000
spf.verbosity = 0
#spf.backend = sbQex
var spf = newSeq[type spa](hmasses.len+1) # fermion force forward
var spfb = newSeq[type spa](hmasses.len+1) # fermion force backward
for i in 0..<spf.len:
spf[i] = initSolverParams()
spf[i].r2req = frsq
spf[i].maxits = 10000
spf[i].verbosity = 0
#spf.backend = sbQex
spfb[i] = initSolverParams()
spfb[i].r2req = frsq
spfb[i].maxits = 10000
spfb[i].verbosity = 0
#spf.backend = sbQex

proc norm2*(x: seq): float =
for i in 0..<x.len:
Expand Down Expand Up @@ -343,9 +350,9 @@ proc addFf(g: GaugeV, i = 0) =
pushCvec()
let cv = cvecvs[^1]
if i == 0:
stag.agradSolve(g, cv, phiv[i], mass, spf)
stag.agradSolve(g, cv, phiv[i], mass, spf[i], spfb[i])
else:
stag.agradSolve(g, cv, phiv[i], vhmasses[i-1], spf)
stag.agradSolve(g, cv, phiv[i], vhmasses[i-1], spf[i], spfb[i])
pushMom()
stag.agradStagDeriv(momvs[^1], cv)
pushMom()
Expand Down Expand Up @@ -1968,6 +1975,9 @@ alwaysAccept = false
#gutime = 0.0
#gftime = 0.0
#fftime = 0.0
for i in 0..<spf.len:
spf[i].resetStats
spfb[i].resetStats
block:
tic()
for n in 1..ntrain:
Expand All @@ -1977,6 +1987,8 @@ block:
m.update
getGrad(m)
let tup = getElapsedTime()
for i in 0..<spf.len:
echo &"FFits{i}: ", spf[i].getAveStats
measure()
if upit > 0:
if n mod upit == 0:
Expand All @@ -1992,6 +2004,9 @@ block:
#echo &"gu: {gutime} gf: {gftime} ff: {fftime} ot: {et-at} tt: {et}"

resetMeasure()
for i in 0..<spf.len:
spf[i].resetStats
spfb[i].resetStats
if trajs > 0:
m.clearStats
pacc.clear
Expand All @@ -2006,6 +2021,8 @@ if trajs > 0:
#echo "cost: ", nff/(vtau.obj*vtau.obj*m.avgPAccept)
echo "cost: ", getCost0(m)
let tup = getElapsedTime()
for i in 0..<spf.len:
echo &"FFits{i}: ", spf[i].getAveStats
measure()
let ttot = getElapsedTime()
echo "End inference update: ", tup, " measure: ", ttot-tup, " total: ", ttot
Expand Down
11 changes: 7 additions & 4 deletions src/hmc/agradOps.nim
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ proc agradSolvebck[I,O](op: AgOp[I,O]) {.nimcall.} =
let g = op.inputs[1]
let x = op.inputs[2]
let m = op.inputs[3]
let p = op.inputs[4]
let p = op.inputs[5]
let r = op.outputs
var c = r.grad.newOneOf
r.grad.odd *= -1
Expand All @@ -583,12 +583,15 @@ proc agradSolvebck[I,O](op: AgOp[I,O]) {.nimcall.} =
when m is AgVar:
if m.doGrad:
m.grad += redot(r.obj, c)
proc agradSolve(c: var AgTape, s,g,r,x,m,p: auto) =
var op = newAgOp((s,g,x,m,p), r, agradSolvefwd, agradSolvebck)
proc agradSolve(c: var AgTape, s,g,r,x,m,pf,pb: auto) =
var op = newAgOp((s,g,x,m,pf,pb), r, agradSolvefwd, agradSolvebck)
c.add op
template agradSolve*(s: Staggered, g,r,x,m,pf,pb: auto) =
## g: gauge, r: result, x: src, m: mass, p: solve params
r.ctx.agradSolve(s, g, r, x, m, addr pf, addr pb)
template agradSolve*(s: Staggered, g,r,x,m,p: auto) =
## g: gauge, r: result, x: src, m: mass, p: solve params
r.ctx.agradSolve(s, g, r, x, m, addr p)
r.ctx.agradSolve(s, g, r, x, m, addr p, addr p)

when isMainModule:
import qex, physics/stagSolve
Expand Down

0 comments on commit e9fa6a3

Please sign in to comment.