diff --git a/src/cgls.jl b/src/cgls.jl index dd7665291..04cc02538 100644 --- a/src/cgls.jl +++ b/src/cgls.jl @@ -50,7 +50,8 @@ function cgls(A :: AbstractLinearOperator, b :: AbstractVector{T}; r = copy(b) bNorm = @knrm2(m, r) # Marginally faster than norm(b); bNorm == 0 && return x, SimpleStats(true, false, [0.0], [0.0], "x = 0 is a zero-residual solution"); - s = A' * M * r; + Mr = M * r + s = A.tprod(Mr) p = copy(s); γ = @kdot(n, s, s) # Faster than γ = dot(s, s); iter = 0; @@ -71,7 +72,8 @@ function cgls(A :: AbstractLinearOperator, b :: AbstractVector{T}; while ! (solved || tired) q = A * p; - δ = @kdot(m, q, M * q) # Faster than α = γ / dot(q, q); + Mq = M * q + δ = @kdot(m, q, Mq) # Faster than α = γ / dot(q, q); λ > 0 && (δ += λ * @kdot(n, p, p)) α = γ / δ; @@ -84,14 +86,12 @@ function cgls(A :: AbstractLinearOperator, b :: AbstractVector{T}; @kaxpy!(n, α, p, x) # Faster than x = x + α * p; @kaxpy!(m, -α, q, r) # Faster than r = r - α * q; - s = A' * M * r; + Mr = M * r + s = A.tprod(Mr); λ > 0 && @kaxpy!(n, -λ, x, s) # s = A' * r - λ * x; γ_next = @kdot(n, s, s) # Faster than γ_next = dot(s, s); β = γ_next / γ; - @kscal!(n, β, p) - @kaxpy!(n, 1.0, s, p) # Faster than p = s + β * p; - # The combined BLAS calls tend to trigger some gc. - # BLAS.axpy!(n, 1.0, s, 1, BLAS.scal!(n, β, p, 1), 1); + @kaxpby!(n, 1.0, s, β, p) # Faster than p = s + β * p; γ = γ_next; rNorm = @knrm2(m, r) # Marginally faster than norm(r); ArNorm = sqrt(γ); diff --git a/test/test_alloc.jl b/test/test_alloc.jl index bf8e7f9ad..a7f2ed24c 100644 --- a/test/test_alloc.jl +++ b/test/test_alloc.jl @@ -83,3 +83,13 @@ expected_cgne_bytes = storage_cgne_bytes(n, m) (x, stats) = cgne(Au, c, M=N) # warmup actual_cgne_bytes = @allocated cgne(Au, c, M=N) @test actual_cgne_bytes ≤ 1.1 * expected_cgne_bytes + +# without preconditioner and with (Ap, Aᵀq) preallocated, CGLS needs: +# - 2 m-vectors: x, p +# - 1 n-vector: r +storage_cgls(n, m) = 2*m + n +storage_cgls_bytes(n, m) = 8 * storage_cgls(n, m) +expected_cgls_bytes = storage_cgls_bytes(n, m) +(x, stats) = cgls(Ao, b, M=M) # warmup +actual_cgls_bytes = @allocated cgls(Ao, b, M=M) +@test actual_cgls_bytes ≤ 1.1 * expected_cgls_bytes