Skip to content

Commit

Permalink
updating the code and status so we have a flag for nonPositive curv a…
Browse files Browse the repository at this point in the history
…nd also
  • Loading branch information
farhadrclass committed Mar 4, 2025
1 parent b3d8f21 commit b60e77c
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 20 deletions.
13 changes: 8 additions & 5 deletions src/krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -112,13 +112,14 @@ mutable struct MinresSolver{T,FC,S} <: KrylovSolver{T,FC,S}
x :: S
r1 :: S
r2 :: S
rk :: S
w1 :: S
w2 :: S
y :: S
v :: S
err_vec :: Vector{T}
warm_start :: Bool
stats :: SimpleStats{T}
stats :: conStats{T}
end

function MinresSolver(kc::KrylovConstructor; window :: Int=5)
Expand All @@ -131,13 +132,14 @@ function MinresSolver(kc::KrylovConstructor; window :: Int=5)
x = similar(kc.vn)
r1 = similar(kc.vn)
r2 = similar(kc.vn)
rk = similar(kc.vn_empty)
w1 = similar(kc.vn)
w2 = similar(kc.vn)
y = similar(kc.vn)
v = similar(kc.vn_empty)
err_vec = zeros(T, window)
stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown")
solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, w1, w2, y, v, err_vec, false, stats)
stats = conStats(0, false, false, false, false, T[], T[], T[], 0.0, "unknown")
solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, rk, w1, w2, y, v, err_vec, false, stats)
return solver
end

Expand All @@ -148,13 +150,14 @@ function MinresSolver(m, n, S; window :: Int=5)
x = S(undef, n)
r1 = S(undef, n)
r2 = S(undef, n)
rk = S(undef, 0)
w1 = S(undef, n)
w2 = S(undef, n)
y = S(undef, n)
v = S(undef, 0)
err_vec = zeros(T, window)
stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown")
solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, w1, w2, y, v, err_vec, false, stats)
stats = conStats(0, false, false, false, false, T[], T[], T[], 0.0, "unknown")
solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, rk, w1, w2, y, v, err_vec, false, stats)
return solver
end

Expand Down
52 changes: 52 additions & 0 deletions src/krylov_stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,58 @@ function copyto!(dest :: SimpleStats, src :: SimpleStats)
return dest
end

"""
Type for storing statistics returned by Conjugate Methods.
Methods icludes:
- CG (TODO)
- CR (TODO)
- MINRES
The fields are as follows:
- niter
- solved
- nonposi_curv: when a non-positive curvature is detected
- linesearch: when a line search is performed
- inconsistent
- residuals
- Aresiduals
- Acond
- timer
- status
"""
mutable struct conStats{T} <: KrylovStats{T}
niter :: Int
solved :: Bool
nonposi_curv :: Bool
linesearch :: Bool
inconsistent :: Bool
residuals :: Vector{T}
Aresiduals :: Vector{T}
Acond :: Vector{T}
timer :: Float64
status :: String
end

function reset!(stats :: conStats)
empty!(stats.residuals)
empty!(stats.Aresiduals)
empty!(stats.Acond)
end

function copyto!(dest :: conStats, src :: conStats)
dest.niter = src.niter
dest.solved = src.solved
dest.nonposi_curv = src.nonposi_curv
dest.linesearch = src.linesearch
dest.inconsistent = src.inconsistent
dest.residuals = copy(src.residuals)
dest.Aresiduals = copy(src.Aresiduals)
dest.Acond = copy(src.Acond)
dest.timer = src.timer
dest.status = src.status
return dest
end


"""
Type for storing statistics returned by LSMR.
The fields are as follows:
Expand Down
38 changes: 23 additions & 15 deletions src/minres.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
# Brussels, Belgium, June 2015.
# Montréal, August 2015.
#
# Liu, Yang & Roosta, Fred. (2022). A Newton-MR algorithm with complexity guarantees for nonconvex smooth unconstrained optimization. 10.48550/arXiv.2208.07095.
# Liu, Yang, and Fred Roosta. "MINRES: from negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32, no. 4 (2022): 2636-2661.

export minres, minres!

Expand Down Expand Up @@ -154,12 +154,16 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm

# Set up workspace.
allocate_if(!MisI, solver, :v, S, solver.x) # The length of v is n
allocate_if(linesearch, solver, :rk, S, solver.x) # The length of rk is n
Δx, x, r1, r2, w1, w2, y = solver.Δx, solver.x, solver.r1, solver.r2, solver.w1, solver.w2, solver.y
err_vec, stats = solver.err_vec, solver.stats
warm_start = solver.warm_start
rNorms, ArNorms, Aconds = stats.residuals, stats.Aresiduals, stats.Acond
reset!(stats)
stats.linesearch = linesearch

v = MisI ? r2 : solver.v
rk = linesearch ? solver.rk : r2

ϵM = eps(T)
ctol = conlim > 0 ? 1 / conlim : zero(T)
Expand All @@ -175,6 +179,8 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm
kcopy!(n, r1, b) # r1 ← b
end

linesearch && kcopy!(n, rk, r1) # rk ← r1

# Initialize Lanczos process.
# β₁ M v₁ = b.
kcopy!(n, r2, r1) # r2 ← r1
Expand All @@ -183,18 +189,6 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm
rNorm = knorm_elliptic(n, r2, r1) # = ‖r‖
history && push!(rNorms, rNorm)

if rNorm == 0
stats.niter = 0
stats.solved, stats.inconsistent = true, false
stats.timer = start_time |> ktimer
stats.status = "x is a zero-residual solution"
history && push!(rNorms, zero(T))
history && push!(ArNorms, zero(T))
history && push!(Aconds, zero(T))
warm_start && kaxpy!(n, one(FC), Δx, x)
solver.warm_start = false
return solver
end

β₁ = kdotr(m, r1, v)
β₁ < 0 && error("Preconditioner is not positive definite")
Expand All @@ -208,7 +202,7 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm
history && push!(Aconds, zero(T))
warm_start && kaxpy!(n, one(FC), Δx, x)
solver.warm_start = false
linesearch && kcopy!(n, x, b) # x ← b
stats.nonposi_curv = true
return solver
end
β₁ = sqrt(β₁)
Expand Down Expand Up @@ -263,7 +257,7 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm
# Generate next Lanczos vector.
mul!(y, A, v)
λ 0 && kaxpy!(n, λ, v, y) # (y = y + λ * v)
kscal!(n, one(FC) / β, y)
kscal!(n, one(FC) / β, y) # (y = y / β)
iter 2 && kaxpy!(n, -β / oldβ, r1, y) # (y = y - β / oldβ * r1)

α = kdotr(n, v, y) / β
Expand Down Expand Up @@ -311,6 +305,9 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm
stats.status = "nonpositive curvature"
iter == 1 && kcopy!(n, x, r)
solver.warm_start = false
# when we use the linesearch and encounter negative curvature, we return the last residual rk
kcopy!(n, x, rk) # x ← rk
stats.nonposi_curv = true
return solver
end
end
Expand All @@ -323,6 +320,16 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm
ϕ = cs * ϕbar
ϕbar = sn * ϕbar

if linesearch
# calculating the residual rk = sn*sn * rk - ϕbar * cs * v
sn2 = sn * sn
kscal!(n, sn2, rk ) # rk = sn2 * rk
ϕ_c = -ϕbar * cs
kaxpy!(n, ϕ_c, v, rk) # rk = rk + ϕ_c * v
rk = sn*sn * rk - ϕbar * cs * v
end


# Final update of w.
kscal!(n, one(FC) / γ, w)

Expand Down Expand Up @@ -412,6 +419,7 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm
user_requested_exit && (status = "user-requested exit")
overtimed && (status = "time limit exceeded")

stats.nonposi_curv = zero_resid
# Update x
warm_start && kaxpy!(n, one(FC), Δx, x)
solver.warm_start = false
Expand Down

0 comments on commit b60e77c

Please sign in to comment.