Skip to content

Commit

Permalink
Better display of Krylov solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Oct 11, 2022
1 parent de4b630 commit 38bfc22
Showing 1 changed file with 71 additions and 28 deletions.
99 changes: 71 additions & 28 deletions src/krylov_solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ GmresSolver, FomSolver, GpmrSolver, FgmresSolver
export solve!, solution, nsolution, statistics, issolved, issolved_primal, issolved_dual,
niterations, Aprod, Atprod, Bprod, warm_start!

import Base.size
import Base.size, Base.sizeof

const KRYLOV_SOLVERS = Dict(
:cg => :CgSolver ,
Expand Down Expand Up @@ -1885,45 +1885,88 @@ for (KS, fun, nsol, nA, nAt, warm_start) in [
end
end

function ksizeof(attribute)
if isa(attribute, AbstractVector) && !isempty(attribute)
# All vectors inside a vector have the same size in Krylov.jl
size_attribute = length(attribute) * ksizeof(attribute[1])
else
size_attribute = sizeof(attribute)
end
return size_attribute
end

function sizeof(solver :: KrylovSolver)
workspace = typeof(solver)
nfields = fieldcount(workspace)
storage = 0
for i = 1:nfields-1
field_i = getfield(solver, i)
size_i = ksizeof(field_i)
storage += size_i
end
return storage
end

function val_metric(val::Int)
metric = "bytes"
if val 1024
val /= 1024
metric = "KB"
if val 1024
val /= 1024
metric = "MB"
if val 1024
val /= 1024
metric = "GB"
if val 1024
val /= 1024
metric = "TB"
end
end
end
val = round(val, digits=2)
end
return string(val, " ", metric)
end

"""
show(io, solver; show_stats=true)
Statistics of `solver` are displayed if `show_stats` is set to true.
"""
function show(io :: IO, solver :: KrylovSolver{T,FC,S}; show_stats :: Bool=true) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: DenseVector{FC}}
workspace = typeof(solver)
name_solver = workspace.name.wrapper
l1 = max(length(string(name_solver)), 10) # length("warm_start") = 10
l2 = length(string(S)) + 8 # length("Vector{}") = 8
name_solver = string(workspace.name.name)
nbytes = sizeof(solver)
storage = val_metric(nbytes)
architecture = S <: Vector ? "CPU" : "GPU"
format = Printf.Format("│%$(l1)s│%$(l2)s│%18s│\n")
format2 = Printf.Format("│%$(l1+1)s│%$(l2)s│%18s│\n")
@printf(io, "┌%s┬%s┬%s┐\n", ""^l1, ""^l2, ""^18)
Printf.format(io, format, name_solver, "Precision: $FC", "Architecture: $architecture")
@printf(io, "├%s┼%s┼%s┤\n", ""^l1, ""^l2, ""^18)
Printf.format(io, format, "Attribute", "Type", "Size")
@printf(io, "├%s┼%s┼%s┤\n", ""^l1, ""^l2, ""^18)
for i=3:fieldcount(workspace)-1 # show m, n and stats seperately
type_i = fieldtype(workspace, i)
l1 = max(length(name_solver) + 8, length(string(FC)) + 11) # length("solver: ") = 8 and length("Precision: ") = 11
l2 = max(ndigits(solver.m) + 7, length(architecture) + 14, length(string(S)) + 8) # length("Vector{}") = 8, # length("Architecture: ") = 14 and length("nrows: ") = 7
l3 = max(ndigits(solver.n) + 7, length(storage) + 9, 13) # length("Size in bytes") = 13, length("Storage: ") = 9 and length("cols: ") = 7
format = Printf.Format("│%$(l1)s│%$(l2)s│%$(l3)s│\n")
format2 = Printf.Format("│%$(l1+1)s│%$(l2)s│%$(l3)s│\n")
@printf(io, "┌%s┬%s┬%s┐\n", ""^l1, ""^l2, ""^l3)
Printf.format(io, format, "Solver: $(name_solver)", "nrows: $(solver.m)", "ncols: $(solver.n)")
@printf(io, "├%s┼%s┼%s┤\n", ""^l1, ""^l2, ""^l3)
Printf.format(io, format, "Precision: $FC", "Architecture: $architecture","Storage: $storage")
@printf(io, "├%s┼%s┼%s┤\n", ""^l1, ""^l2, ""^l3)
Printf.format(io, format, "Attribute", "Type", "Size in bytes")
@printf(io, "├%s┼%s┼%s┤\n", ""^l1, ""^l2, ""^l3)
for i=1:fieldcount(workspace)-1 # show stats seperately
name_i = fieldname(workspace, i)
len = if type_i <: AbstractVector
field_i = getfield(solver, name_i)
ni = length(field_i)
if eltype(type_i) <: AbstractVector
"$(ni) x $(length(field_i[1]))"
else
length(field_i)
end
else
0
end
type_i = fieldtype(workspace, i)
field_i = getfield(solver, name_i)
size_i = ksizeof(field_i)
if (name_i in [:w̅, :w̄, :d̅]) && (VERSION < v"1.8.0-DEV")
Printf.format(io, format2, string(name_i), type_i, len)
Printf.format(io, format2, string(name_i), type_i, size_i)
else
Printf.format(io, format, string(name_i), type_i, len)
Printf.format(io, format, string(name_i), type_i, size_i)
end
end
@printf(io, "└%s┴%s┴%s┘\n",""^l1,""^l2,""^18)
show_stats && show(io, solver.stats)
@printf(io, "└%s┴%s┴%s┘\n",""^l1,""^l2,""^l3)
if show_stats
@printf(io, "\n")
show(io, solver.stats)
end
return nothing
end

0 comments on commit 38bfc22

Please sign in to comment.