Skip to content

Commit c4c5c2a

Browse files
committed
improve output and tests
1 parent 8eef2d2 commit c4c5c2a

File tree

5 files changed

+41
-7
lines changed

5 files changed

+41
-7
lines changed

Diff for: src/cg.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ function optimize(fg, x, alg::ConjugateGradient;
169169
numiter, t, f, normgrad)
170170
else
171171
verbosity >= 1 &&
172-
@warn @sprintf("CG: not converged to requested tol: f = %.12f, ‖∇f‖ = %.4e",
173-
f, normgrad)
172+
@warn @sprintf("CG: not converged to requested tol after %d iterations and time %.2f s: f = %.12f, ‖∇f‖ = %.4e",
173+
numiter, t, f, normgrad)
174174
end
175175
history = [fhistory normgradhistory]
176176
return x, f, g, numfg, history

Diff for: src/gd.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ function optimize(fg, x, alg::GradientDescent;
117117
numiter, t, f, normgrad)
118118
else
119119
verbosity >= 1 &&
120-
@warn @sprintf("GD: not converged to requested tol: f = %.12f, ‖∇f‖ = %.4e",
121-
f, normgrad)
120+
@warn @sprintf("GD: not converged to requested tol after %d iterations and time %.2f s: f = %.12f, ‖∇f‖ = %.4e",
121+
numiter, t, f, normgrad)
122122
end
123123
history = [fhistory normgradhistory]
124124
return x, f, g, numfg, history

Diff for: src/lbfgs.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ function optimize(fg, x, alg::LBFGS;
194194
numiter, t, f, normgrad)
195195
else
196196
verbosity >= 1 &&
197-
@warn @sprintf("LBFGS: not converged to requested tol: f = %.12f, ‖∇f‖ = %.4e",
198-
f, normgrad)
197+
@warn @sprintf("LBFGS: not converged to requested tol after %d iterations and time %.2f s: f = %.12f, ‖∇f‖ = %.4e",
198+
numiter, t, f, normgrad)
199199
end
200200
history = [fhistory normgradhistory]
201201
return x, f, g, numfg, history

Diff for: src/tangentvector.jl

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
using VectorInterface
2+
3+
struct TangentVector{M,T,F1,F2,F3}
4+
x::M
5+
v::T
6+
inner::F1
7+
add!!::F2
8+
scale!!::F3
9+
end
10+
11+
Base.getindex(tv::TangentVector) = tv.v
12+
base(tv::TangentVector) = tv.x
13+
function checkbase(tv1::TangentVector, tv2::TangentVector)
14+
return tv1.x === tv2.x ? tv1.x :
15+
throw(ArgumentError("tangent vectors with different base points"))
16+
end
17+
18+
function VectorInterface.scale!!(tv::TangentVector, α::Real)
19+
tv.v = tv.scale!!(tv.v, α)
20+
return tv
21+
end
22+
function VectorInterface.add!!(tv::TV, α::Real, tv2::TV) where {TV<:TangentVector}
23+
checkbase(tv, tv2)
24+
tv.v = tv.add!!(tv.v, α, tv2.v)
25+
return tv
26+
end
27+
function VectorInterface.inner(tv::TV, tv2::TV) where {TV<:TangentVector}
28+
return tv.inner(checkbase(tv, tv2), tv.v, tv2.v)
29+
end
30+
31+
function retract(x::M, η::TangentVector{M}, α::Real) where {M}
32+
33+
return x, η
34+
end

Diff for: test/runtests.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ algorithms = (GradientDescent, ConjugateGradient, LBFGS)
5757
A = randn(n, n)
5858
fg = quadraticproblem(A' * A, y)
5959
x₀ = randn(n)
60-
alg = algtype(; verbosity=2, gradtol=1e-12)
60+
alg = algtype(; verbosity=2, gradtol=1e-9, maxiter=10_000_000)
6161
x, f, g, numfg, normgradhistory = optimize(fg, x₀, alg)
6262
@test x ≈ y
6363
@test f < 1e-14

0 commit comments

Comments (0)