Skip to content

Commit

Permalink
Improve roots_quadratic
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Sep 27, 2022
1 parent fb509d6 commit 1a582cd
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 95 deletions.
2 changes: 1 addition & 1 deletion src/cg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ function cg!(solver :: CgSolver{T,FC,S}, A, b :: AbstractVector{FC};
α = γ / pAp

# Compute step size to boundary if applicable.
σ = radius > 0 ? maximum(to_boundary(x, p, radius, dNorm2=pNorm²)) : α
σ = radius > 0 ? maximum(to_boundary(n, x, p, radius, dNorm2=pNorm²)) : α

kdisplay(iter, verbose) && @printf("%8.1e %8.1e %8.1e\n", pAp, α, σ)

Expand Down
2 changes: 1 addition & 1 deletion src/cgls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ function cgls!(solver :: CglsSolver{T,FC,S}, A, b :: AbstractVector{FC};
α = γ / δ

# if a trust-region constraint is given, compute step to the boundary
σ = radius > 0 ? maximum(to_boundary(x, p, radius)) : α
σ = radius > 0 ? maximum(to_boundary(n, x, p, radius)) : α
if (radius > 0) &> σ)
α = σ
on_boundary = true
Expand Down
4 changes: 2 additions & 2 deletions src/cr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ function cr!(solver :: CrSolver{T,FC,S}, A, b :: AbstractVector{FC};
(verbose > 0) && @printf("radius = %8.1e > 0 and ‖x‖ = %8.1e\n", radius, xNorm)
# find t1 > 0 and t2 < 0 such that ‖x + ti * p‖² = radius² (i = 1, 2)
xNorm² = xNorm * xNorm
t = to_boundary(x, p, radius; flip = false, xNorm2 = xNorm², dNorm2 = pNorm²)
t = to_boundary(n, x, p, radius; flip = false, xNorm2 = xNorm², dNorm2 = pNorm²)
t1 = maximum(t) # > 0
t2 = minimum(t) # < 0
tr = maximum(to_boundary(x, r, radius; flip = false, xNorm2 = xNorm², dNorm2 = rNorm²))
tr = maximum(to_boundary(n, x, r, radius; flip = false, xNorm2 = xNorm², dNorm2 = rNorm²))
(verbose > 0) && @printf("t1 = %8.1e, t2 = %8.1e and tr = %8.1e\n", t1, t2, tr)

if abspAp γ * pNorm * @knrm2(n, q) # pᴴAp ≃ 0
Expand Down
4 changes: 2 additions & 2 deletions src/crls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,10 +151,10 @@ function crls!(solver :: CrlsSolver{T,FC,S}, A, b :: AbstractVector{FC};
p = Ar # p = Aᴴr
pNorm² = ArNorm * ArNorm
mul!(q, Aᴴ, s)
α = min(ArNorm^2 / γ, maximum(to_boundary(x, p, radius, flip = false, dNorm2 = pNorm²))) # the quadratic is minimal in the direction Aᴴr for α = ‖Ar‖²/γ
α = min(ArNorm^2 / γ, maximum(to_boundary(n, x, p, radius, flip = false, dNorm2 = pNorm²))) # the quadratic is minimal in the direction Aᴴr for α = ‖Ar‖²/γ
else
pNorm² = pNorm * pNorm
σ = maximum(to_boundary(x, p, radius, flip = false, dNorm2 = pNorm²))
σ = maximum(to_boundary(n, x, p, radius, flip = false, dNorm2 = pNorm²))
if α σ
α = σ
on_boundary = true
Expand Down
78 changes: 41 additions & 37 deletions src/krylov_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ function sym_givens(a :: Complex{T}, b :: Complex{T}) where T <: AbstractFloat
return (c, s, ρ)
end

# Mixed real/complex convenience methods for `sym_givens`: the real argument is
# converted to `Complex{T}` and the call is forwarded to the method that takes
# two `Complex{T}` arguments.
# NOTE(review): the `@inline` pair and the plain pair below have identical
# signatures; they look like the before/after sides of a diff — confirm that only
# one pair is meant to remain.
@inline sym_givens(a :: Complex{T}, b :: T) where T <: AbstractFloat = sym_givens(a, Complex{T}(b))
@inline sym_givens(a :: T, b :: Complex{T}) where T <: AbstractFloat = sym_givens(Complex{T}(a), b)
sym_givens(a :: Complex{T}, b :: T) where T <: AbstractFloat = sym_givens(a, Complex{T}(b))
sym_givens(a :: T, b :: Complex{T}) where T <: AbstractFloat = sym_givens(Complex{T}(a), b)

"""
roots = roots_quadratic(q₂, q₁, q₀; nitref)
Expand All @@ -111,19 +111,19 @@ function roots_quadratic(q₂ :: T, q₁ :: T, q₀ :: T;
# Case where q(x) is linear.
if q₂ == zero(T)
if q₁ == zero(T)
root = tuple(zero(T))
q₀ == zero(T) || (root = tuple())
q₀ == zero(T) || error("The quadratic `q` doesn't have real roots.")
root = zero(T)
else
root = tuple(-q₀ / q₁)
root = -q₀ / q₁
end
return root
return (root, root)
end

# Case where q(x) is indeed quadratic.
rhs = eps(T) * q₁ * q₁
if abs(q₀ * q₂) > rhs
ρ = q₁ * q₁ - 4 * q₂ * q₀
ρ < 0 && return tuple()
ρ < 0 && return error("The quadratic `q` doesn't have real roots.")
d = -(q₁ + copysign(sqrt(ρ), q₁)) / 2
root1 = d / q₂
root2 = q₀ / d
Expand All @@ -150,36 +150,6 @@ function roots_quadratic(q₂ :: T, q₁ :: T, q₀ :: T;
return (root1, root2)
end


"""
    roots = to_boundary(x, d, radius; flip, xNorm2, dNorm2)

Given a trust-region radius `radius`, a vector `x` lying inside the
trust-region and a direction `d`, return `σ1` and `σ2` such that

    ‖x + σi d‖ = radius, i = 1, 2

in the Euclidean norm. If known, ‖x‖² and ‖d‖² may be supplied in `xNorm2`
and `dNorm2`; a value of zero means "not supplied" and the quantity is
computed with `dot`.

If `flip` is set to `true`, `σ1` and `σ2` are computed such that

    ‖x - σi d‖ = radius, i = 1, 2.

An error is raised if `radius` is not positive, if ‖d‖² is zero, or if
‖x‖² exceeds `radius`².
"""
function to_boundary(x :: Vector{T}, d :: Vector{T},
                     radius :: T; flip :: Bool=false, xNorm2 :: T=zero(T), dNorm2 :: T=zero(T)) where T <: Number
  radius > 0 || error("radius must be positive")

  # ‖d‖² σ² + (xᴴd + dᴴx) σ + (‖x‖² - radius²).
  rxd = real(dot(x, d))
  flip && (rxd = -rxd)
  dNorm2 == zero(T) && (dNorm2 = dot(d, d))
  dNorm2 == zero(T) && error("zero direction")
  xNorm2 == zero(T) && (xNorm2 = dot(x, x))
  # `x` must lie inside (or on) the trust region, otherwise the quadratic may have no real root.
  (xNorm2 ≤ radius * radius) || error(@sprintf("outside of the trust region: ‖x‖²=%7.1e, Δ²=%7.1e", xNorm2, radius * radius))
  roots = roots_quadratic(dNorm2, 2 * rxd, xNorm2 - radius * radius)
  return roots # `σ1` and `σ2`
end

"""
s = vec2str(x; ndisp)
Expand Down Expand Up @@ -357,3 +327,37 @@ end
macro kref!(n, x, y, c, s)
return esc(:(reflect!($x, $y, $c, $s)))
end

"""
    roots = to_boundary(n, x, d, radius; flip, xNorm2, dNorm2)

Given a trust-region radius `radius`, a vector `x` lying inside the
trust-region and a direction `d`, return `σ1` and `σ2` such that

    ‖x + σi d‖ = radius, i = 1, 2

in the Euclidean norm.
`n` is the length of vectors `x` and `d`.
If known, ‖x‖² and ‖d‖² may be supplied with `xNorm2` and `dNorm2`;
a value of zero means "not supplied" and the quantity is computed.

If `flip` is set to `true`, `σ1` and `σ2` are computed such that

    ‖x - σi d‖ = radius, i = 1, 2.

An error is raised if `radius` is not positive, if ‖d‖² is zero, or if
‖x‖² exceeds `radius`².
"""
function to_boundary(n :: Int, x :: Vector{T}, d :: Vector{T}, radius :: T; flip :: Bool=false, xNorm2 :: T=zero(T), dNorm2 :: T=zero(T)) where T <: FloatOrComplex
  radius > 0 || error("radius must be positive")

  # ‖d‖² σ² + (xᴴd + dᴴx) σ + (‖x‖² - Δ²).
  rxd = @kdotr(n, x, d)
  flip && (rxd = -rxd)
  dNorm2 == zero(T) && (dNorm2 = @kdot(n, d, d))
  dNorm2 == zero(T) && error("zero direction")
  xNorm2 == zero(T) && (xNorm2 = @kdot(n, x, x))
  radius2 = radius * radius
  # `x` must lie inside (or on) the trust region so that the discriminant below is nonnegative.
  (xNorm2 ≤ radius2) || error(@sprintf("outside of the trust region: ‖x‖²=%7.1e, Δ²=%7.1e", xNorm2, radius2))

  # q₂ = ‖d‖², q₁ = xᴴd + dᴴx, q₀ = ‖x‖² - Δ²
  # ‖x‖² ≤ Δ² ⟹ (q₁)² - 4 * q₂ * q₀ ≥ 0
  roots = roots_quadratic(dNorm2, 2 * rxd, xNorm2 - radius2)
  return roots # `σ1` and `σ2`
end
2 changes: 1 addition & 1 deletion src/lsmr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ function lsmr!(solver :: LsmrSolver{T,FC,S}, A, b :: AbstractVector{FC};
# the step ϕ/ρ is not necessarily positive
σ = ζ /* ρbar)
if radius > 0
t1, t2 = to_boundary(x, hbar, radius)
t1, t2 = to_boundary(n, x, hbar, radius)
tmax, tmin = max(t1, t2), min(t1, t2)
on_boundary = σ > tmax || σ < tmin
σ = σ > 0 ? min(σ, tmax) : max(σ, tmin)
Expand Down
2 changes: 1 addition & 1 deletion src/lsqr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ function lsqr!(solver :: LsqrSolver{T,FC,S}, A, b :: AbstractVector{FC};
# the step ϕ/ρ is not necessarily positive
σ = ϕ / ρ
if radius > 0
t1, t2 = to_boundary(x, w, radius)
t1, t2 = to_boundary(n, x, w, radius)
tmax, tmin = max(t1, t2), min(t1, t2)
on_boundary = σ > tmax || σ < tmin
σ = σ > 0 ? min(σ, tmax) : max(σ, tmin)
Expand Down
81 changes: 31 additions & 50 deletions test/test_aux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,102 +36,83 @@
@testset "roots_quadratic" begin
  # Checks the real roots returned by `Krylov.roots_quadratic(q₂, q₁, q₀)` on
  # degenerate, well-conditioned and ill-conditioned quadratics, and that the
  # calls do not allocate.
  # NOTE(review): this block interleaves assertions written for a
  # variable-length result (`length(roots)`) with assertions written for a
  # two-root result and thrown errors — presumably the before/after sides of
  # the commit; confirm which set should remain.

  # test roots of a quadratic
  roots = Krylov.roots_quadratic(0.0, 0.0, 0.0)
  @test length(roots) == 1
  @test roots[1] == 0.0
  @test roots[2] == 0.0

  roots = Krylov.roots_quadratic(0.0, 0.0, 1.0)
  @test length(roots) == 0
  @test_throws ErrorException Krylov.roots_quadratic(0.0, 0.0, 1.0)

  roots = Krylov.roots_quadratic(0.0, 3.14, -1.0)
  @test length(roots) == 1
  @test roots[1] == 1.0 / 3.14
  @test roots[2] == 1.0 / 3.14

  roots = Krylov.roots_quadratic(1.0, 0.0, 1.0)
  @test length(roots) == 0
  @test_throws ErrorException Krylov.roots_quadratic(1.0, 0.0, 1.0)

  roots = Krylov.roots_quadratic(1.0, 0.0, 0.0)
  @test length(roots) == 2
  @test roots[1] == 0.0
  @test roots[2] == 0.0

  # x² + 3x + 2 = (x + 2)(x + 1)
  roots = Krylov.roots_quadratic(1.0, 3.0, 2.0)
  @test length(roots) == 2
  @test roots[1] ≈ -2.0
  @test roots[2] ≈ -1.0

  roots = Krylov.roots_quadratic(1.0e+8, 1.0, 1.0)
  @test length(roots) == 0
  @test_throws ErrorException Krylov.roots_quadratic(1.0e+8, 1.0, 1.0)

  # ill-conditioned quadratic
  roots = Krylov.roots_quadratic(-1.0e-8, 1.0e+5, 1.0, nitref=0)
  @test length(roots) == 2
  @test roots[1] == 1.0e+13
  @test roots[2] == 0.0

  # iterative refinement is crucial!
  roots = Krylov.roots_quadratic(-1.0e-8, 1.0e+5, 1.0, nitref=1)
  @test length(roots) == 2
  @test roots[1] == 1.0e+13
  @test roots[2] == -1.0e-05

  # not ill-conditioned quadratic
  roots = Krylov.roots_quadratic(-1.0e-7, 1.0, 1.0, nitref=0)
  @test length(roots) == 2
  @test isapprox(roots[1], 1.0e+7, rtol=1.0e-6)
  @test isapprox(roots[2], -1.0, rtol=1.0e-6)

  roots = Krylov.roots_quadratic(-1.0e-7, 1.0, 1.0, nitref=1)
  @test length(roots) == 2
  @test isapprox(roots[1], 1.0e+7, rtol=1.0e-6)
  @test isapprox(roots[2], -1.0, rtol=1.0e-6)

  # NOTE(review): the comparison operator on the next line was lost in
  # extraction; `≥` is assumed — confirm against the repository.
  if VERSION ≥ v"1.8"
    allocations = @allocated Krylov.roots_quadratic(0.0, 0.0, 0.0)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(0.0, 0.0, 1.0)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(0.0, 3.14, -1.0)
    @test allocations == 0
    allocations = @allocated Krylov.roots_quadratic(0.0, 0.0, 0.0)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(1.0, 0.0, 1.0)
    @test allocations == 0
    allocations = @allocated Krylov.roots_quadratic(0.0, 3.14, -1.0)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(1.0, 0.0, 0.0)
    @test allocations == 0
    allocations = @allocated Krylov.roots_quadratic(1.0, 0.0, 0.0)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(1.0, 3.0, 2.0)
    @test allocations == 0
    allocations = @allocated Krylov.roots_quadratic(1.0, 3.0, 2.0)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(1.0e+8, 1.0, 1.0)
    @test allocations == 0
    allocations = @allocated Krylov.roots_quadratic(-1.0e-8, 1.0e+5, 1.0, nitref=0)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(-1.0e-8, 1.0e+5, 1.0, nitref=0)
    @test allocations == 0
    allocations = @allocated Krylov.roots_quadratic(-1.0e-8, 1.0e+5, 1.0, nitref=1)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(-1.0e-8, 1.0e+5, 1.0, nitref=1)
    @test allocations == 0
    allocations = @allocated Krylov.roots_quadratic(-1.0e-7, 1.0, 1.0, nitref=0)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(-1.0e-7, 1.0, 1.0, nitref=0)
    @test allocations == 0

    allocations = @allocated Krylov.roots_quadratic(-1.0e-7, 1.0, 1.0, nitref=1)
    @test allocations == 0
  end
  allocations = @allocated Krylov.roots_quadratic(-1.0e-7, 1.0, 1.0, nitref=1)
  @test allocations == 0
end

@testset "to_boundary" begin
  # Checks `Krylov.to_boundary` on x = (1, 1, 1, 1, 1) and d = (-1, 1, -1, 1, -1):
  # invalid inputs must throw, and for radius 5 the step lengths solve
  # 5σ² - 2σ - 20 = 0 (or 5σ² + 2σ - 20 = 0 with `flip=true`).
  # NOTE(review): calls with and without the leading `n` argument are both
  # present — presumably the before/after sides of the commit; confirm which
  # set should remain.

  # test trust-region boundary
  x = ones(5)
  d = ones(5); d[1:2:5] .= -1
  @test_throws ErrorException Krylov.to_boundary(x, d, -1.0)
  @test_throws ErrorException Krylov.to_boundary(x, d, 0.5)
  @test_throws ErrorException Krylov.to_boundary(x, zeros(5), 1.0)
  @test maximum(Krylov.to_boundary(x, d, 5.0)) ≈ 2.209975124224178
  @test minimum(Krylov.to_boundary(x, d, 5.0)) ≈ -1.8099751242241782
  @test maximum(Krylov.to_boundary(x, d, 5.0, flip=true)) ≈ 1.8099751242241782
  @test minimum(Krylov.to_boundary(x, d, 5.0, flip=true)) ≈ -2.209975124224178
  n = 5
  x = ones(n)
  d = ones(n); d[1:2:n] .= -1
  @test_throws ErrorException Krylov.to_boundary(n, x, d, -1.0)
  @test_throws ErrorException Krylov.to_boundary(n, x, d, 0.5)
  @test_throws ErrorException Krylov.to_boundary(n, x, zeros(n), 1.0)
  @test maximum(Krylov.to_boundary(n, x, d, 5.0)) ≈ 2.209975124224178
  @test minimum(Krylov.to_boundary(n, x, d, 5.0)) ≈ -1.8099751242241782
  @test maximum(Krylov.to_boundary(n, x, d, 5.0, flip=true)) ≈ 1.8099751242241782
  @test minimum(Krylov.to_boundary(n, x, d, 5.0, flip=true)) ≈ -2.209975124224178
end

@testset "kzeros" begin
Expand Down

0 comments on commit 1a582cd

Please sign in to comment.