diff --git a/src/krylov_solvers.jl b/src/krylov_solvers.jl index 2cc3197f5..37be1373d 100644 --- a/src/krylov_solvers.jl +++ b/src/krylov_solvers.jl @@ -8,7 +8,7 @@ GmresSolver, FomSolver, GpmrSolver, FgmresSolver export solve!, solution, nsolution, statistics, issolved, issolved_primal, issolved_dual, niterations, Aprod, Atprod, Bprod, warm_start! -import Base.size +import Base.size, Base.sizeof const KRYLOV_SOLVERS = Dict( :cg => :CgSolver , @@ -1885,6 +1885,50 @@ for (KS, fun, nsol, nA, nAt, warm_start) in [ end end +function ksizeof(attribute) + if isa(attribute, AbstractVector) && !isempty(attribute) + # All vectors inside a vector have the same size in Krylov.jl + size_attribute = length(attribute) * ksizeof(attribute[1]) + else + size_attribute = sizeof(attribute) + end + return size_attribute +end + +function sizeof(solver :: KrylovSolver) + workspace = typeof(solver) + nfields = fieldcount(workspace) + storage = 0 + for i = 1:nfields-1 + field_i = getfield(solver, i) + size_i = ksizeof(field_i) + storage += size_i + end + return storage +end + +function val_metric(val::Int) + metric = "bytes" + if val ≥ 1024 + val /= 1024 + metric = "KB" + if val ≥ 1024 + val /= 1024 + metric = "MB" + if val ≥ 1024 + val /= 1024 + metric = "GB" + if val ≥ 1024 + val /= 1024 + metric = "TB" + end + end + end + val = round(val, digits=2) + end + return string(val, " ", metric) +end + """ show(io, solver; show_stats=true) @@ -1892,38 +1936,37 @@ Statistics of `solver` are displayed if `show_stats` is set to true. """ function show(io :: IO, solver :: KrylovSolver{T,FC,S}; show_stats :: Bool=true) where {T <: AbstractFloat, FC <: FloatOrComplex{T}, S <: DenseVector{FC}} workspace = typeof(solver) - name_solver = workspace.name.wrapper - l1 = max(length(string(name_solver)), 10) # length("warm_start") = 10 - l2 = length(string(S)) + 8 # length("Vector{}") = 8 + name_solver = string(workspace.name.name) + nbytes = sizeof(solver) + storage = val_metric(nbytes) architecture = S <: Vector ? "CPU" : "GPU" - format = Printf.Format("│%$(l1)s│%$(l2)s│%18s│\n") - format2 = Printf.Format("│%$(l1+1)s│%$(l2)s│%18s│\n") - @printf(io, "┌%s┬%s┬%s┐\n", "─"^l1, "─"^l2, "─"^18) - Printf.format(io, format, name_solver, "Precision: $FC", "Architecture: $architecture") - @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^18) - Printf.format(io, format, "Attribute", "Type", "Size") - @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^18) - for i=3:fieldcount(workspace)-1 # show m, n and stats seperately - type_i = fieldtype(workspace, i) + l1 = max(length(name_solver) + 8, length(string(FC)) + 11) # length("solver: ") = 8 and length("Precision: ") = 11 + l2 = max(ndigits(solver.m) + 7, length(architecture) + 14, length(string(S)) + 8) # length("Vector{}") = 8, # length("Architecture: ") = 14 and length("nrows: ") = 7 + l3 = max(ndigits(solver.n) + 7, length(storage) + 9, 13) # length("Size in bytes") = 13, length("Storage: ") = 9 and length("cols: ") = 7 + format = Printf.Format("│%$(l1)s│%$(l2)s│%$(l3)s│\n") + format2 = Printf.Format("│%$(l1+1)s│%$(l2)s│%$(l3)s│\n") + @printf(io, "┌%s┬%s┬%s┐\n", "─"^l1, "─"^l2, "─"^l3) + Printf.format(io, format, "Solver: $(name_solver)", "nrows: $(solver.m)", "ncols: $(solver.n)") + @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^l3) + Printf.format(io, format, "Precision: $FC", "Architecture: $architecture","Storage: $storage") + @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^l3) + Printf.format(io, format, "Attribute", "Type", "Size in bytes") + @printf(io, "├%s┼%s┼%s┤\n", "─"^l1, "─"^l2, "─"^l3) + for i=1:fieldcount(workspace)-1 # show stats seperately name_i = fieldname(workspace, i) - len = if type_i <: AbstractVector - field_i = getfield(solver, name_i) - ni = length(field_i) - if eltype(type_i) <: AbstractVector - "$(ni) x $(length(field_i[1]))" - else - length(field_i) - end - else - 0 - end + type_i = fieldtype(workspace, i) + field_i = getfield(solver, name_i) + size_i = ksizeof(field_i) if (name_i in [:w̅, :w̄, :d̅]) && (VERSION < v"1.8.0-DEV") - Printf.format(io, format2, string(name_i), type_i, len) + Printf.format(io, format2, string(name_i), type_i, size_i) else - Printf.format(io, format, string(name_i), type_i, len) + Printf.format(io, format, string(name_i), type_i, size_i) end end - @printf(io, "└%s┴%s┴%s┘\n","─"^l1,"─"^l2,"─"^18) - show_stats && show(io, solver.stats) + @printf(io, "└%s┴%s┴%s┘\n","─"^l1,"─"^l2,"─"^l3) + if show_stats + @printf(io, "\n") + show(io, solver.stats) + end return nothing end