Skip to content

Commit

Permalink
Fix cgls_lanczos_shift.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Jan 17, 2025
1 parent 2e5ed14 commit 8d66d80
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 37 deletions.
18 changes: 8 additions & 10 deletions src/cgls_lanczos_shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,12 @@ kwargs_cgls_lanczos_shift = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose

# Set up workspace.
allocate_if(!MisI, solver, :v, S, solver.Mv) # The length of v is n
u_prev, utilde = solver.Mv_prev, solver.Mv_next
u = solver.u
v, u_prev, u, u_next = solver.Mv, solver.u_prev, solver.u, solver.u_next
x, p, σ, δhat = solver.x, solver.p, solver.σ, solver.δhat
ω, γ, rNorms, converged = solver.ω, solver.γ, solver.rNorms, solver.converged
not_cv, stats = solver.not_cv, solver.stats
rNorms_history, status = stats.residuals, stats.status
reset!(stats)
v = solver.v

# Initial state.
## Distribute x similarly to shifts.
Expand Down Expand Up @@ -198,16 +196,16 @@ kwargs_cgls_lanczos_shift = (:M, :ldiv, :atol, :rtol, :itmax, :timemax, :verbose
while ! (solved || tired || user_requested_exit || overtimed)

# Form next Lanczos vector.
mul!(utilde, A, v) # utildeₖ ← Avₖ
δ = kdotr(m, utilde, utilde) # δₖ = vₖᵀAᴴAvₖ
kaxpy!(m, -δ, u, utilde) # uₖ₊₁ = utildeₖ - δₖuₖ - βₖuₖ₋₁
kaxpy!(m, -β, u_prev, utilde)
mul!(v, Aᴴ, utilde) # vₖ₊₁ = Aᴴuₖ₊₁
mul!(u_next, A, v) # u_nextₖ ← Avₖ
δ = kdotr(m, u_next, u_next) # δₖ = vₖᵀAᴴAvₖ
kaxpy!(m, -δ, u, u_next) # uₖ₊₁ = u_nextₖ - δₖuₖ - βₖuₖ₋₁
kaxpy!(m, -β, u_prev, u_next)
mul!(v, Aᴴ, u_next) # vₖ₊₁ = Aᴴuₖ₊₁
β = knorm_elliptic(n, v, v) # βₖ₊₁ = vₖ₊₁ᵀ M vₖ₊₁
kscal!(n, one(FC) / β, v) # vₖ₊₁ ← vₖ₊₁ / βₖ₊₁
kscal!(m, one(FC) / β, utilde) # uₖ₊₁ = uₖ₊₁ / βₖ₊₁
kscal!(m, one(FC) / β, u_next) # uₖ₊₁ = uₖ₊₁ / βₖ₊₁
kcopy!(m, u_prev, u) # u_prev ← u
kcopy!(m, u, utilde) # u ← utilde
kcopy!(m, u, u_next) # u ← u_next

MisI ||= kdotr(n, v, v))
for i = 1 : nshifts
Expand Down
48 changes: 24 additions & 24 deletions src/krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1719,24 +1719,24 @@ The outer constructors:
can be used to initialize this workspace.
"""
mutable struct CglsLanczosShiftSolver{T,FC,S} <: KrylovSolver{T,FC,S}
m :: Int
n :: Int
nshifts :: Int
Mv :: S
Mv_prev :: S
Mv_next :: S
u :: S
v :: S
x :: Vector{S}
p :: Vector{S}
σ :: Vector{T}
δhat :: Vector{T}
ω :: Vector{T}
γ :: Vector{T}
rNorms :: Vector{T}
converged :: BitVector
not_cv :: BitVector
stats :: LanczosShiftStats{T}
m :: Int
n :: Int
nshifts :: Int
Mv :: S
u_prev :: S
u_next :: S
u :: S
v :: S
x :: Vector{S}
p :: Vector{S}
σ :: Vector{T}
δhat :: Vector{T}
ω :: Vector{T}
γ :: Vector{T}
rNorms :: Vector{T}
converged :: BitVector
not_cv :: BitVector
stats :: LanczosShiftStats{T}
end

function CglsLanczosShiftSolver(kc::KrylovConstructor, nshifts)
Expand All @@ -1746,8 +1746,8 @@ function CglsLanczosShiftSolver(kc::KrylovConstructor, nshifts)
m = length(kc.vm)
n = length(kc.vn)
Mv = ksimilar(kc.vn)
Mv_prev = ksimilar(kc.vn)
Mv_next = ksimilar(kc.vn)
u_prev = ksimilar(kc.vm)
u_next = ksimilar(kc.vm)
u = ksimilar(kc.vm)
v = ksimilar(kc.vn_empty)
x = S[ksimilar(kc.vn) for i = 1 : nshifts]
Expand All @@ -1761,16 +1761,16 @@ function CglsLanczosShiftSolver(kc::KrylovConstructor, nshifts)
converged = BitVector(undef, nshifts)
not_cv = BitVector(undef, nshifts)
stats = LanczosShiftStats(0, false, Vector{T}[T[] for i = 1 : nshifts], indefinite, T(NaN), T(NaN), 0.0, "unknown")
solver = CglsLanczosShiftSolver{T,FC,S}(m, n, nshifts, Mv, Mv_prev, Mv_next, u, v, x, p, σ, δhat, ω, γ, rNorms, converged, not_cv, stats)
solver = CglsLanczosShiftSolver{T,FC,S}(m, n, nshifts, Mv, u_prev, u_next, u, v, x, p, σ, δhat, ω, γ, rNorms, converged, not_cv, stats)
return solver
end

function CglsLanczosShiftSolver(m, n, nshifts, S)
FC = eltype(S)
T = real(FC)
Mv = S(undef, n)
Mv_prev = S(undef, n)
Mv_next = S(undef, n)
u_prev = S(undef, m)
u_next = S(undef, m)
u = S(undef, m)
v = S(undef, 0)
x = S[S(undef, n) for i = 1 : nshifts]
Expand All @@ -1784,7 +1784,7 @@ function CglsLanczosShiftSolver(m, n, nshifts, S)
converged = BitVector(undef, nshifts)
not_cv = BitVector(undef, nshifts)
stats = LanczosShiftStats(0, false, Vector{T}[T[] for i = 1 : nshifts], indefinite, T(NaN), T(NaN), 0.0, "unknown")
solver = CglsLanczosShiftSolver{T,FC,S}(m, n, nshifts, Mv, Mv_prev, Mv_next, u, v, x, p, σ, δhat, ω, γ, rNorms, converged, not_cv, stats)
solver = CglsLanczosShiftSolver{T,FC,S}(m, n, nshifts, Mv, u_prev, u_next, u, v, x, p, σ, δhat, ω, γ, rNorms, converged, not_cv, stats)
return solver
end

Expand Down
6 changes: 3 additions & 3 deletions test/test_allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,12 +396,12 @@

@testset "CGLS-LANCZOS-SHIFT" begin
# CGLS-LANCZOS-SHIFT needs:
# - 3 n-vectors: Mv_prev, Mv, Mv_next
# - 1 m-vector: u
# - 1 n-vector: Mv
# - 3 m-vectors: u_prev, u, u_next
# - 2 (n*nshifts)-matrices: x, p
# - 5 nshifts-vectors: σ, δhat, ω, γ, rNorms
# - 3 nshifts-bitVector: converged, indefinite, not_cv
storage_cgls_lanczos_shift_bytes(m, n, nshifts) = nbits_FC * (3 * n + 1 * m + 2 * n * nshifts) + nbits_T * (5 * nshifts) + (3 * nshifts)
storage_cgls_lanczos_shift_bytes(m, n, nshifts) = nbits_FC * (1 * n + 3 * m + 2 * n * nshifts) + nbits_T * (5 * nshifts) + (3 * nshifts)

expected_cgls_lanczos_shift_bytes = storage_cgls_lanczos_shift_bytes(m, k, nshifts)
(x, stats) = cgls_lanczos_shift(Ao, b, shifts) # warmup
Expand Down

0 comments on commit 8d66d80

Please sign in to comment.