diff --git a/src/krylov_solvers.jl b/src/krylov_solvers.jl index 905eda930..c9c49a784 100644 --- a/src/krylov_solvers.jl +++ b/src/krylov_solvers.jl @@ -112,13 +112,14 @@ mutable struct MinresSolver{T,FC,S} <: KrylovSolver{T,FC,S} x :: S r1 :: S r2 :: S + rk :: S w1 :: S w2 :: S y :: S v :: S err_vec :: Vector{T} warm_start :: Bool - stats :: SimpleStats{T} + stats :: conStats{T} end function MinresSolver(kc::KrylovConstructor; window :: Int=5) @@ -131,13 +132,14 @@ function MinresSolver(kc::KrylovConstructor; window :: Int=5) x = similar(kc.vn) r1 = similar(kc.vn) r2 = similar(kc.vn) + rk = similar(kc.vn_empty) w1 = similar(kc.vn) w2 = similar(kc.vn) y = similar(kc.vn) v = similar(kc.vn_empty) err_vec = zeros(T, window) - stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown") - solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, w1, w2, y, v, err_vec, false, stats) + stats = conStats(0, false, false, false, false, T[], T[], T[], 0.0, "unknown") + solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, rk, w1, w2, y, v, err_vec, false, stats) return solver end @@ -148,13 +150,14 @@ function MinresSolver(m, n, S; window :: Int=5) x = S(undef, n) r1 = S(undef, n) r2 = S(undef, n) + rk = S(undef, 0) w1 = S(undef, n) w2 = S(undef, n) y = S(undef, n) v = S(undef, 0) err_vec = zeros(T, window) - stats = SimpleStats(0, false, false, T[], T[], T[], 0.0, "unknown") - solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, w1, w2, y, v, err_vec, false, stats) + stats = conStats(0, false, false, false, false, T[], T[], T[], 0.0, "unknown") + solver = MinresSolver{T,FC,S}(m, n, Δx, x, r1, r2, rk, w1, w2, y, v, err_vec, false, stats) return solver end diff --git a/src/krylov_stats.jl b/src/krylov_stats.jl index c383b7914..c9e87ea0a 100644 --- a/src/krylov_stats.jl +++ b/src/krylov_stats.jl @@ -47,6 +47,58 @@ function copyto!(dest :: SimpleStats, src :: SimpleStats) return dest end +""" +Type for storing statistics returned by Conjugate Methods. +Methods icludes: +- CG (TODO) +- CR (TODO) +- MINRES +The fields are as follows: +- niter +- solved +- nonposi_curv: when a non-positive curvature is detected +- linesearch: when a line search is performed +- inconsistent +- residuals +- Aresiduals +- Acond +- timer +- status +""" +mutable struct conStats{T} <: KrylovStats{T} + niter :: Int + solved :: Bool + nonposi_curv :: Bool + linesearch :: Bool + inconsistent :: Bool + residuals :: Vector{T} + Aresiduals :: Vector{T} + Acond :: Vector{T} + timer :: Float64 + status :: String +end + +function reset!(stats :: conStats) + empty!(stats.residuals) + empty!(stats.Aresiduals) + empty!(stats.Acond) +end + +function copyto!(dest :: conStats, src :: conStats) + dest.niter = src.niter + dest.solved = src.solved + dest.nonposi_curv = src.nonposi_curv + dest.linesearch = src.linesearch + dest.inconsistent = src.inconsistent + dest.residuals = copy(src.residuals) + dest.Aresiduals = copy(src.Aresiduals) + dest.Acond = copy(src.Acond) + dest.timer = src.timer + dest.status = src.status + return dest +end + + """ Type for storing statistics returned by LSMR. The fields are as follows: diff --git a/src/minres.jl b/src/minres.jl index a955eee20..bf6824a9a 100644 --- a/src/minres.jl +++ b/src/minres.jl @@ -19,7 +19,7 @@ # Brussels, Belgium, June 2015. # Montréal, August 2015. # -# Liu, Yang & Roosta, Fred. (2022). A Newton-MR algorithm with complexity guarantees for nonconvex smooth unconstrained optimization. 10.48550/arXiv.2208.07095. +# Liu, Yang, and Fred Roosta. "MINRES: from negative curvature detection to monotonicity properties." SIAM Journal on Optimization 32, no. 4 (2022): 2636-2661. export minres, minres! @@ -154,12 +154,16 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm # Set up workspace. allocate_if(!MisI, solver, :v, S, solver.x) # The length of v is n + allocate_if(linesearch, solver, :rk, S, solver.x) # The length of rk is n Δx, x, r1, r2, w1, w2, y = solver.Δx, solver.x, solver.r1, solver.r2, solver.w1, solver.w2, solver.y err_vec, stats = solver.err_vec, solver.stats warm_start = solver.warm_start rNorms, ArNorms, Aconds = stats.residuals, stats.Aresiduals, stats.Acond reset!(stats) + stats.linesearch = linesearch + v = MisI ? r2 : solver.v + rk = linesearch ? solver.rk : r2 ϵM = eps(T) ctol = conlim > 0 ? 1 / conlim : zero(T) @@ -175,6 +179,8 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm kcopy!(n, r1, b) # r1 ← b end + linesearch && kcopy!(n, rk, r1) # rk ← r1 + # Initialize Lanczos process. # β₁ M v₁ = b. kcopy!(n, r2, r1) # r2 ← r1 @@ -183,18 +189,6 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm rNorm = knorm_elliptic(n, r2, r1) # = ‖r‖ history && push!(rNorms, rNorm) - if rNorm == 0 - stats.niter = 0 - stats.solved, stats.inconsistent = true, false - stats.timer = start_time |> ktimer - stats.status = "x is a zero-residual solution" - history && push!(rNorms, zero(T)) - history && push!(ArNorms, zero(T)) - history && push!(Aconds, zero(T)) - warm_start && kaxpy!(n, one(FC), Δx, x) - solver.warm_start = false - return solver - end β₁ = kdotr(m, r1, v) β₁ < 0 && error("Preconditioner is not positive definite") @@ -208,7 +202,7 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm history && push!(Aconds, zero(T)) warm_start && kaxpy!(n, one(FC), Δx, x) solver.warm_start = false - linesearch && kcopy!(n, x, b) # x ← b + stats.nonposi_curv = true return solver end β₁ = sqrt(β₁) @@ -263,7 +257,7 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm # Generate next Lanczos vector. mul!(y, A, v) λ ≠ 0 && kaxpy!(n, λ, v, y) # (y = y + λ * v) - kscal!(n, one(FC) / β, y) + kscal!(n, one(FC) / β, y) # (y = y / β) iter ≥ 2 && kaxpy!(n, -β / oldβ, r1, y) # (y = y - β / oldβ * r1) α = kdotr(n, v, y) / β @@ -311,6 +305,9 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm stats.status = "nonpositive curvature" iter == 1 && kcopy!(n, x, r) solver.warm_start = false + # when we use the linesearch and encounter negative curvature, we return the last residual rk + kcopy!(n, x, rk) # x ← rk + stats.nonposi_curv = true return solver end end @@ -323,6 +320,16 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm ϕ = cs * ϕbar ϕbar = sn * ϕbar + if linesearch + # calculating the residual rk = sn*sn * rk - ϕbar * cs * v + sn2 = sn * sn + kscal!(n, sn2, rk ) # rk = sn2 * rk + ϕ_c = -ϕbar * cs + kaxpy!(n, ϕ_c, v, rk) # rk = rk + ϕ_c * v + rk = sn*sn * rk - ϕbar * cs * v + end + + # Final update of w. kscal!(n, one(FC) / γ, w) @@ -412,6 +419,7 @@ kwargs_minres = (:M, :ldiv, :linesearch ,:λ, :atol, :rtol, :etol, :conlim, :itm user_requested_exit && (status = "user-requested exit") overtimed && (status = "time limit exceeded") + stats.nonposi_curv = zero_resid # Update x warm_start && kaxpy!(n, one(FC), Δx, x) solver.warm_start = false