Skip to content

Commit

Permalink
fix it all!
Browse files Browse the repository at this point in the history
  • Loading branch information
tmigot committed Feb 7, 2024
1 parent 3f9e200 commit aeb2b4e
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 12 deletions.
2 changes: 0 additions & 2 deletions docs/src/examples/cgls_lanczos_shift.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
# TODO

```@example cgls_lanczos_shift
using MatrixMarket, SuiteSparseMatrixCollection
using Krylov, LinearOperators
Expand Down
3 changes: 2 additions & 1 deletion src/cgls_lanczos_shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ kwargs_cgls_lanczos_shift = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax,
stats.niter = 0
stats.solved = true
stats.timer = ktimer(start_time)
status = "x = 0 is a zero-residual solution"
stats.status = "x = 0 is a zero-residual solution"
return solver
end

Expand Down Expand Up @@ -275,6 +275,7 @@ kwargs_cgls_lanczos_shift = (:M, :ldiv, :check_curvature, :atol, :rtol, :itmax,
stats.niter = iter
stats.solved = solved
stats.timer = ktimer(start_time)
stats.status = status
return solver
end
end
2 changes: 1 addition & 1 deletion test/test_allocations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@
actual_cgls_lanczos_shift_bytes = @allocated cgls_lanczos_shift(Ao, b, shifts)
@test expected_cgls_lanczos_shift_bytes actual_cgls_lanczos_shift_bytes 1.02 * expected_cgls_lanczos_shift_bytes

solver = CglsLanczosShiftSolver(Ao, b)
solver = CglsLanczosShiftSolver(Ao, b, length(shifts))
cgls_lanczos_shift!(solver, Ao, b, shifts) # warmup
inplace_cgls_lanczos_shift_bytes = @allocated cgls_lanczos_shift!(solver, Ao, b, shifts)
@test inplace_cgls_lanczos_shift_bytes == 0
Expand Down
2 changes: 1 addition & 1 deletion test/test_cgls_lanczos_shift.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ end
for xi x
@test norm(xi) == 0
end
@test status == "x = 0 is a zero-residual solution"
@test stats.status == "x = 0 is a zero-residual solution"

#=
# Not implemented
Expand Down
6 changes: 5 additions & 1 deletion test/test_mp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
x, y, _ = @eval $fn($A, $B, $b, $c)
elseif fn in (:lnlq, :craig, :craigmr)
x, y, _ = @eval $fn($A, $b)
elseif fn == :cg_lanczos_shift
elseif fn in (:cg_lanczos_shift, :cgls_lanczos_shift)
x, _ = @eval $fn($A, $b, $shifts)
else
x, _ = @eval $fn($A, $b)
Expand All @@ -42,6 +42,10 @@
@test norm((A - I) * x[1] - b) Κ * (atol + norm(b) * rtol)
@test norm((A + I) * x[2] - b) Κ * (atol + norm(b) * rtol)
@test eltype(x) == Vector{FC}
elseif fn == :cgls_lanczos_shift
@test norm(A' * (b - A * x[1]) + x[1]) Κ * (atol + norm(b) * rtol)
@test norm(A' * (b - A * x[2]) - x[2]) Κ * (atol + norm(b) * rtol)
@test eltype(x) == Vector{FC}
else
@test norm(A * x - b) Κ * (atol + norm(b) * rtol)
@test eltype(x) == FC
Expand Down
13 changes: 7 additions & 6 deletions test/test_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ function test_solvers(FC)
m = div(n, 2)
Au = A[1:m,:] # Dimension m x n
Ao = A[:,1:m] # Dimension n x m
b = Ao * ones(FC, m) # Dimension n
c = Au * ones(FC, n) # Dimension m
b = Ao * ones(FC, m) # Dimension m
c = Au * ones(FC, n) # Dimension n
mem = 10
shifts = [1.0; 2.0; 3.0; 4.0; 5.0]
nshifts = 5
Expand Down Expand Up @@ -70,8 +70,8 @@ function test_solvers(FC)
method == :cg_lanczos_shift && @test_throws ErrorException("solver.nshifts = $(solver.nshifts) is inconsistent with length(shifts) = $(length(shifts2))") solve!(solver, A, b, shifts2)
method (:cgne, :crmr, :lnlq, :craig, :craigmr) && @test_throws ErrorException("(solver.m, solver.n) = ($(solver.m), $(solver.n)) is inconsistent with size(A) = ($m2, $n2)") solve!(solver, Au2, c2)
method (:cgls, :crls, :lslq, :lsqr, :lsmr) && @test_throws ErrorException("(solver.m, solver.n) = ($(solver.m), $(solver.n)) is inconsistent with size(A) = ($n2, $m2)") solve!(solver, Ao2, b2)
hod == :cg_lanczos_shift && @test_throws ErrorException("(solver.m, solver.n) = ($(solver.m), $(solver.n)) is inconsistent with size(A) = ($n2, $n2)") solve!(solver, Ao2, b2, shifts2)
method == :cg_lanczos_shift && @test_throws ErrorException("solver.nshifts = $(solver.nshifts) is inconsistent with length(shifts) = $(length(shifts2))") solve!(solver, A, b, shifts2)
method == :cgls_lanczos_shift && @test_throws ErrorException("(solver.m, solver.n) = ($(solver.m), $(solver.n)) is inconsistent with size(A) = ($n2, $m2)") solve!(solver, Ao2, b2, shifts2)
method == :cgls_lanczos_shift && @test_throws ErrorException("solver.nshifts = $(solver.nshifts) is inconsistent with length(shifts) = $(length(shifts2))") solve!(solver, Ao, b, shifts2)
method (:bilqr, :trilqr) && @test_throws ErrorException("(solver.m, solver.n) = ($(solver.m), $(solver.n)) is inconsistent with size(A) = ($n2, $n2)") solve!(solver, A2, b2, b2)
method == :gpmr && @test_throws ErrorException("(solver.m, solver.n) = ($(solver.m), $(solver.n)) is inconsistent with size(A) = ($n2, $m2)") solve!(solver, Ao2, Au2, b2, c2)
method (:tricg, :trimr) && @test_throws ErrorException("(solver.m, solver.n) = ($(solver.m), $(solver.n)) is inconsistent with size(A) = ($n2, $m2)") solve!(solver, Ao2, b2, c2)
Expand All @@ -87,6 +87,7 @@ function test_solvers(FC)
method == :cg_lanczos_shift && solve!(solver, A, b, shifts, timemax=timemax)
method (:cgne, :crmr, :lnlq, :craig, :craigmr) && solve!(solver, Au, c, timemax=timemax)
method (:cgls, :crls, :lslq, :lsqr, :lsmr) && solve!(solver, Ao, b, timemax=timemax)
method == :cgls_lanczos_shift && solve!(solver, Ao, b, shifts, timemax=timemax)
method (:bilqr, :trilqr) && solve!(solver, A, b, b, timemax=timemax)
method == :gpmr && solve!(solver, Ao, Au, b, c, timemax=timemax)
method (:tricg, :trimr) && solve!(solver, Au, c, b, timemax=timemax)
Expand Down Expand Up @@ -125,8 +126,8 @@ function test_solvers(FC)
(nsolution == 2) && (@test solution(solver, 2) == solver.y)
end

if method (:cgls, :crls, :lslq, :lsqr, :lsmr)
solve!(solver, Ao, b)
if method (:cgls, :crls, :lslq, :lsqr, :lsmr, :cgls_lanczos_shift)
method == :cgls_lanczos_shift ? solve!(solver, Ao, b, shifts) : solve!(solver, Ao, b)
niter = niterations(solver)
@test Aprod(solver) == niter
@test Atprod(solver) == niter
Expand Down

0 comments on commit aeb2b4e

Please sign in to comment.