diff --git a/TypedSyntax/src/node.jl b/TypedSyntax/src/node.jl index 40b7cbf7..c83f07e0 100644 --- a/TypedSyntax/src/node.jl +++ b/TypedSyntax/src/node.jl @@ -191,12 +191,12 @@ function map_signature!(sig::TypedSyntaxNode, slotnames::Vector{Symbol}, slottyp kwdivider = 1 if havekws && slotnames[1] !== Symbol("#self#") kwdivider = findfirst(1:length(slotnames)) do i - slotnames[i] == Symbol("") && unwrapinternal(slottypes[i]) <: Function # this should be the parent function as an argument + slotnames[i] == Symbol("") && isa(unwrapinternal(slottypes[i]), Function) # this should be the parent function as an argument end if kwdivider === nothing kwdivider = 1 end - if length(slottypes) >= 2 && slotnames[2] == Symbol("") && (nt = unwrapinternal(slottypes[2])) <: NamedTuple + if length(slottypes) >= 2 && slotnames[2] == Symbol("") && (nt = unwrapinternal(slottypes[2]); isa(nt, Type)) && nt <: NamedTuple # Match kwargs argcontainer = children(last(children(sig))) offset = length(children(sig)) - 1 @@ -244,7 +244,7 @@ function map_signature!(sig::TypedSyntaxNode, slotnames::Vector{Symbol}, slottyp if kind(arg) == K"::" && length(children(arg)) == 2 arg = child(arg, 1) end - arg.typ = unwrapinternal(slottypes[idx]) + arg.typ = slottypes[idx] end # It's annoying to print the signature as `foo::typeof(foo)(a::Int)` @@ -276,7 +276,7 @@ function striparg(arg) end function unwrapinternal(@nospecialize(T)) - isa(T, Core.Const) && return Core.Typeof(T.val) + isa(T, Core.Const) && return T.val isa(T, Core.PartialStruct) && return T.typ return T end @@ -287,10 +287,10 @@ function gettyp(node2ssa, node, src) ssavaluetypes = src.ssavaluetypes::Vector{Any} if isa(stmt, Core.ReturnNode) arg = stmt.val - isa(arg, SSAValue) && return unwrapinternal(ssavaluetypes[arg.id]) - is_slot(arg) && return unwrapinternal((src.slottypes::Vector{Any})[arg.id]) + isa(arg, SSAValue) && return ssavaluetypes[arg.id] + is_slot(arg) && return (src.slottypes::Vector{Any})[arg.id] end - return unwrapinternal(ssavaluetypes[i]) + return ssavaluetypes[i] end Base.copy(tsd::TypedSyntaxData) = TypedSyntaxData(tsd.source, tsd.typedsource, tsd.raw, tsd.position, tsd.val, tsd.typ, tsd.runtime) @@ -608,7 +608,7 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN argmapping = typeof(rootnode)[] # temporary storage for (i, mapped, stmt) in zip(eachindex(mappings), mappings, src.code) empty!(argmapping) - if is_slot(stmt) || isa(stmt, SSAValue) + if is_slot(stmt) || isa(stmt, SSAValue) || isa(stmt, GlobalRef) append_targets_for_arg!(mapped, i, stmt) elseif isa(stmt, Core.ReturnNode) append_targets_for_line!(mapped, i, append_targets_for_arg!(argmapping, i, stmt.val)) @@ -626,16 +626,14 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN append_targets_for_arg!(mapped, i, stmt) filter_assignment_targets!(mapped, true) # match the RHS of assignments if length(mapped) == 1 - symtyps[only(mapped)] = unwrapinternal( - (is_slot(stmt) & have_slottypes) ? slottypes[(stmt::SlotType).id] : + symtyps[only(mapped)] = (is_slot(stmt) & have_slottypes) ? slottypes[(stmt::SlotType).id] : isa(stmt, SSAValue) ? ssavaluetypes[stmt.id] : #=literal=#typeof(stmt) - ) end # Now try to assign types to the LHS of the assignment append_targets_for_arg!(argmapping, i, lhs) filter_assignment_targets!(argmapping, false) # match the LHS of assignments if length(argmapping) == 1 - T = unwrapinternal(ssavaluetypes[i]) + T = ssavaluetypes[i] symtyps[only(argmapping)] = T end empty!(argmapping) @@ -738,7 +736,7 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN if isexpr(nextstmt, :call) f = nextstmt.args[1] if isa(f, GlobalRef) && f.mod == Base && f.name == :broadcasted - empty!(mapped) + # empty!(mapped) break elseif isa(f, GlobalRef) && f.mod == Base && f.name == :materialize && nextstmt.args[2] === SSAValue(i) push!(mappings[inext], node) @@ -793,14 +791,14 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN haskey(symtyps, t) && continue if skipped_parent(t) == node is_prec_assignment(node) && t == child(node, 1) && continue - symtyps[t] = unwrapinternal(if j > 0 + symtyps[t] = if j > 0 ssavaluetypes[j] elseif have_slottypes # We failed to find it as an SSAValue, it must have type assigned at function entry slottypes[arg.id] else nothing - end) + end break end end @@ -904,6 +902,9 @@ function skipped_parent(node::SyntaxNode) pnode === nothing && return node ppnode = pnode.parent if ppnode !== nothing && kind(pnode) ∈ KSet"... quote" # might need to add more things here + if kind(node) == K"Identifier" && kind(pnode) == K"quote" && kind(ppnode) == K"." && sourcetext(node) == "materialize" + return ppnode.parent + end return ppnode end return pnode diff --git a/TypedSyntax/src/show.jl b/TypedSyntax/src/show.jl index f3bfbe05..b4197b3e 100644 --- a/TypedSyntax/src/show.jl +++ b/TypedSyntax/src/show.jl @@ -87,22 +87,56 @@ end function is_show_annotation(@nospecialize(T); type_annotations::Bool, hide_type_stable::Bool) type_annotations || return false if isa(T, Core.Const) - T = typeof(T.val) + isa(T.val, Module) && return false + T = Core.Typeof(T.val) end isa(T, Type) || return false hide_type_stable || return true return isa(T, Type) && is_type_unstable(T) end +# Is the type equivalent to the source-text? +# We use `endswith` to handle module qualification +is_type_transparent(node, @nospecialize(T)) = endswith(replace(sprint(show, T), r"\s" => ""), replace(sourcetext(node), r"\s" => "")) + +function is_callfunc(node::MaybeTypedSyntaxNode, @nospecialize(T)) + thisnode = node + pnode = node.parent + while pnode !== nothing && kind(pnode) ∈ KSet"quote ." && pnode.parent !== nothing + thisnode = pnode + pnode = pnode.parent + end + if pnode !== nothing && kind(pnode) ∈ (K"call", K"curly") && ((is_infix_op_call(pnode) && is_operator(thisnode)) || thisnode === pnode.children[1]) + if isa(T, Core.Const) + T = T.val + end + if isa(T, Type) || isa(T, Function) + T === Colon() && sourcetext(node) == ":" && return true + return is_type_transparent(node, T) + end + end + return false +end + function type_annotation_mode(node, @nospecialize(T); type_annotations::Bool, hide_type_stable::Bool) kind(node) == K"return" && return false, "", "", "" + is_callfunc(node, T) && return false, "", "", "" type_annotate = is_show_annotation(T; type_annotations, hide_type_stable) pre = pre2 = post = "" if type_annotate - if T isa DataType && T <: Type && isassigned(T.parameters, 1) - if replace(sourcetext(node), r"\s" => "") == replace(sprint(show, T.parameters[1]), r"\s" => "") - return false, pre, pre2, post + # Try stripping Core.Const and Type{T} wrappers to check if we need to avoid `String::Type{String}` + # or `String::Core.Const(String)` annotations + S = nothing + if isa(T, Core.Const) + val = T.val + if isa(val, DataType) + S = val end + elseif isa(T, DataType) && T <: Type && isassigned(T.parameters, 1) + S = T.parameters[1] + end + if S !== nothing && is_type_transparent(node, S) + return false, pre, pre2, post end if kind(node) ∈ KSet":: where" || is_infix_op_call(node) || (is_prec_assignment(node) && kind(node) != K"=") pre, post = "(", ")" @@ -118,6 +152,9 @@ function show_annotation(io, @nospecialize(T), post, node, position; iswarn::Boo inlay_hints = get(io, :inlay_hints, nothing) print(io, post) + if isa(T, Core.Const) && isa(T.val, Type) + T = Type{T.val} + end T_str = string(T) if iswarn && is_type_unstable(T) color = is_small_union_or_tunion(T) ? :yellow : :red diff --git a/TypedSyntax/test/runtests.jl b/TypedSyntax/test/runtests.jl index d0115587..b1677ea9 100644 --- a/TypedSyntax/test/runtests.jl +++ b/TypedSyntax/test/runtests.jl @@ -2,6 +2,7 @@ using JuliaSyntax: JuliaSyntax, SyntaxNode, children, child, sourcetext, kind, @ using TypedSyntax: TypedSyntax, TypedSyntaxNode using Dates, InteractiveUtils, Test +has_name_typ(node, name::Symbol, @nospecialize(Ts::Tuple)) = kind(node) == K"Identifier" && node.val === name && node.typ in Ts has_name_typ(node, name::Symbol, @nospecialize(T)) = kind(node) == K"Identifier" && node.val === name && node.typ === T has_name_notyp(node, name::Symbol) = has_name_typ(node, name, nothing) @@ -49,7 +50,7 @@ include("test_module.jl") @test has_name_typ(child(body, 1), :x, Int) @test has_name_typ(child(body, 3, 2, 1), :x, Int) pi4 = child(body, 3, 2, 3) - @test kind(pi4) == K"call" && pi4.typ == typeof(π / 4) + @test kind(pi4) == K"call" && pi4.typ === Core.Const(π / 4) tsn = TypedSyntaxNode(TSN.has2xa, (Real,)) @test tsn.typ === Any sig, body = children(tsn) @@ -213,18 +214,18 @@ include("test_module.jl") tsn = TypedSyntaxNode(TSN.nestedgenerators, (Int, Int)) sig, body = children(tsn) @test kind(body) == K"generator" - @test body.typ <: Base.Iterators.Flatten + @test TypedSyntax.unwrapinternal(body.typ) <: Base.Iterators.Flatten tsn = TypedSyntaxNode(TSN.nestedgenerators, (Int,)) sig, body = children(tsn) @test kind(body) == K"generator" - @test body.typ <: Base.Iterators.Flatten + @test TypedSyntax.unwrapinternal(body.typ) <: Base.Iterators.Flatten tsn = TypedSyntaxNode(TSN.nestedexplicit, (Int,)) sig, body = children(tsn) @test kind(body) == K"comprehension" @test body.typ <: Vector node = child(body, 1) @test kind(node) == K"generator" - @test node.typ <: Base.Generator + @test TypedSyntax.unwrapinternal(node.typ) <: Base.Generator # Broadcasting tsn = TypedSyntaxNode(TSN.fbroadcast, (Vector{Int},)) @@ -237,9 +238,9 @@ include("test_module.jl") sig, body = children(tsn) @test body.typ === Float64 cnode = child(body, 2) + @test cnode.typ === Vector{Float64} cnodef = child(cnode, 1, 2, 1) @test kind(cnodef) == K"Identifier" && cnodef.val == :materialize - @test cnode.typ === Vector{Float64} cnode = child(body, 2, 2) cnodef = child(cnode, 1, 2, 1) @test kind(cnodef) == K"Identifier" && cnodef.val == :broadcasted @@ -248,12 +249,7 @@ include("test_module.jl") sig, body = children(tsn) node = child(body, 2) src = tsn.typedsource - if isa(src.code[1], GlobalRef) - @test kind(node) == K"dotcall" && node.typ === Vector{String} - else - # We aren't quite handling this properly yet - @test_broken kind(node) == K"dotcall" && node.typ === Vector{String} - end + @test kind(node) == K"dotcall" && node.typ === Vector{String} tsn = TypedSyntaxNode(TSN.bcast415, (TSN.B415, Float64)) sig, body = children(tsn) @test child(body, 1).typ === Float64 @@ -289,7 +285,7 @@ include("test_module.jl") isz = child(body, 2, 1, 1) @test kind(isz) == K"call" && child(isz, 1).val == :iszero @test isz.typ === Bool - @test child(body, 2, 1, 2).typ == Float64 + @test child(body, 2, 1, 2).typ === Core.Const(NaN) # default positional arguments tsn = TypedSyntaxNode(TSN.defaultarg, (Float32,)) @@ -307,7 +303,7 @@ include("test_module.jl") tsn = TypedSyntaxNode(TSN.hasdefaulttypearg, (Type{Float32},)) sig, body = children(tsn) arg = child(sig, 1, 2, 1) - @test kind(arg) == K"::" && arg.typ === Type{Float32} + @test kind(arg) == K"::" && arg.typ === Core.Const(Float32) tsn = TypedSyntaxNode(TSN.hasdefaulttypearg, ()) sig, body = children(tsn) arg = child(sig, 1, 2, 1) @@ -332,7 +328,7 @@ include("test_module.jl") @test tsn.typ == Union{Int,Float64} sig, body = children(tsn) @test has_name_typ(child(sig, 2), :list, Vector{Float64}) - @test has_name_typ(child(body, 1, 1), :s, Int) + @test has_name_typ(child(body, 1, 1), :s, Core.Const(0)) @test has_name_typ(child(body, 2, 1, 1), :x, Float64) node = child(body, 2, 2, 1) @test kind(node) == K"+=" @@ -350,8 +346,8 @@ include("test_module.jl") tsn = TypedSyntaxNode(TSN.zerowhere, (Vector{Int16},)) sig, body = children(tsn) @test child(sig, 1, 2).typ === Vector{Int16} - @test body.typ === Int16 - @test has_name_typ(child(body, 2), :T, Type{Int16}) + @test body.typ === Core.Const(Int16(0)) + @test has_name_typ(child(body, 2), :T, (Core.Const(Int16), Type{Int16})) # tsn = TypedSyntaxNode(TSN.vaparam, (Matrix{Float32}, (String, Bool))) # fails on `which` m = @which TSN.vaparam(rand(3,3), ("hello", false)) mi = first(specializations(m)) @@ -371,10 +367,10 @@ include("test_module.jl") @test has_name_typ(child(body, 2), :Bool, Type{Bool}) tsn = TypedSyntaxNode(TSN.unnamedargs, (Type{Matrix{Float32}}, Type{Int})) sig, body = children(tsn) + @test child(sig, 1, 2).typ === Core.Const(Matrix{Float32}) + @test child(sig, 1, 3).typ === Core.Const(Int) m = @which TSN.unnamedargs(Matrix{Float32}, Int, Int) fbody = Base.bodyfunction(m) - @test child(sig, 1, 2).typ === Type{Matrix{Float32}} - @test child(sig, 1, 3).typ === Type{Int} m = @which TSN.unnamedargs(Matrix{Float32}, Int; a="hello") mi = nothing for _mi in specializations(m) @@ -388,8 +384,8 @@ include("test_module.jl") end tsn = TypedSyntaxNode(mi) sig, body = children(tsn) - @test child(sig, 1, 2).typ === Type{Matrix{Float32}} - @test child(sig, 1, 3).typ === Type{Int} + @test child(sig, 1, 2).typ === Core.Const(Matrix{Float32}) + @test child(sig, 1, 3).typ === Core.Const(Int) @test has_name_notyp(child(sig, 1, 4, 1), :c) @test has_name_typ(child(sig, 1, 5, 1, 1), :a, String) m = @which TSN.unnamedargs(Matrix{Float32}, Int, :c; a="hello") @@ -403,8 +399,8 @@ include("test_module.jl") end tsn = TypedSyntaxNode(mi) sig, body = children(tsn) - @test child(sig, 1, 2).typ === Type{Matrix{Float32}} - @test child(sig, 1, 3).typ === Type{Int} + @test child(sig, 1, 2).typ === Core.Const(Matrix{Float32}) + @test child(sig, 1, 3).typ === Core.Const(Int) @test child(sig, 1, 4, 1).typ === Symbol @test child(sig, 1, 5, 1, 1).typ === String mbody = only(methods(fbody)) @@ -418,8 +414,8 @@ include("test_module.jl") end tsn = TypedSyntaxNode(mi) sig, body = children(tsn) - @test child(sig, 1, 2).typ === Type{Matrix{Float32}} - @test child(sig, 1, 3).typ === Type{Int} + @test child(sig, 1, 2).typ === Core.Const(Matrix{Float32}) + @test child(sig, 1, 3).typ === Core.Const(Int) @test child(sig, 1, 4, 1).typ === Symbol @test child(sig, 1, 5, 1, 1).typ === String tsn = TypedSyntaxNode(TSN.unnamedargs2, (Type{Matrix}, Symbol)) @@ -458,7 +454,7 @@ include("test_module.jl") src = tsn.typedsource @test Symbol("kwargs...") ∈ src.slotnames sig, body = children(tsn) - @test child(body, 2, 1).typ <: Base.Iterators.Pairs + @test TypedSyntax.unwrapinternal(child(body, 2, 1).typ) <: Base.Iterators.Pairs # quoted symbols that could be confused for function definition tsn = TypedSyntaxNode(TSN.isexpreq, (Expr,)) @@ -471,10 +467,10 @@ include("test_module.jl") sig, body = children(tsn) errnode = child(body, 1, 2) errf = child(errnode, 1) - @test errnode.typ === nothing && errf.typ === typeof(Base.throw_boundserror) + @test errnode.typ === nothing && errf.typ === Core.Const(Base.throw_boundserror) retnode = child(body, 2) @test kind(retnode) == K"return" - @test retnode.typ === nothing || retnode.typ === Nothing + @test retnode.typ === Core.Const(nothing) || retnode.typ === nothing # julia 1.10 doesn't assign a type to the Core.ReturnNode # Globals & scoped assignment tsn = TypedSyntaxNode(TSN.setglobal, (Char,)) @@ -486,7 +482,7 @@ include("test_module.jl") tsn = TypedSyntaxNode(TSN.myoftype, (Float64, Int)) sig, body = children(tsn) node = child(body, 1) - @test node.typ === Type{Float64} + @test node.typ === Core.Const(Float64) tsn = TypedSyntaxNode(TSN.DefaultArray{Float32}, (Vector{Int}, Int)) sig, body = children(tsn) @test kind(sig) == K"where" @@ -517,7 +513,7 @@ include("test_module.jl") tsn = TypedSyntaxNode(TSN.myoftype, (Float64, Int)) sig, body = children(tsn) node = child(body, 1) - @test node.typ === Type{Float64} + @test node.typ === Core.Const(Float64) # UnionAll in signature (issue #409) tsn = TypedSyntaxNode(Core.kwcall, (NamedTuple, typeof(issorted), Vector{Int})) @@ -571,9 +567,9 @@ include("test_module.jl") str = sprint(tsn; context=:color=>false) do io, obj printstyled(io, obj; hide_type_stable=false) end - @test occursin("s::$Int = 0::$Int", str) + @test occursin("s::$Int = 0::$Int", str) || occursin("s::Core.Const(0) = 0::Core.Const(0)", str) @test !occursin("(s::$Int = 0::$Int)", str) - @test occursin("(s::Float64 += x::Float64)::Float64", str) + @test occursin("(s::Float64 += x::Float64)::Float64", str) || occursin("(s::Union{Float64, $Int} += x::Float64)::Float64", str) tsn = TypedSyntaxNode(TSN.zerowhere, (Vector{Int16},)) str = sprint(tsn; context=:color=>true) do io, obj printstyled(io, obj; iswarn=true, hide_type_stable=false) @@ -619,6 +615,32 @@ include("test_module.jl") printstyled(io, obj; hide_type_stable=false) end @test !occursin("::Type{Dict{String, Any}}", str) + tsn = TypedSyntaxNode(TSN.obfuscated, (Float64,)) + str = sprint(tsn; context=:color=>false) do io, obj + printstyled(io, obj; hide_type_stable=false) + end + @test occursin("::Core.Const(sin)", str) || occursin("::typeof(sin)", str) + tsn = TypedSyntaxNode(TSN.calls_helper, (Float32,)) + str = sprint(tsn; context=:color=>false) do io, obj + printstyled(io, obj; hide_type_stable=false) + end + @test !occursin("Core.Const", str) + tsn = TypedSyntaxNode(TSN.calls_helper1, (Float32,)) + str = sprint(tsn; context=:color=>false) do io, obj + printstyled(io, obj; hide_type_stable=false) + end + @test !occursin("Core.Const", str) + tsn = TypedSyntaxNode(TSN.calls_helper2, (Float32,)) + str = sprint(tsn; context=:color=>false) do io, obj + printstyled(io, obj; hide_type_stable=false) + end + @test !occursin("Core.Const", str) + tsn = TypedSyntaxNode(TSN.allbutfirst, (Vector{Bool},)) + str = sprint(tsn; context=:color=>false) do io, obj + printstyled(io, obj; hide_type_stable=false) + end + @test occursin("2:end", str) + # issue #413 @test TypedSyntax.is_small_union_or_tunion(Union{}) @@ -710,8 +732,8 @@ using TypedSyntax: InlayHint, Diagnostic, InlayHintKinds "(" "::$Int" ")::Bool" - "::$Int" - "::Float64" + "::Core.Const(-1)" + "::Core.Const(1.0)" ")::Union{Float64, $Int}"] @test length(io[:diagnostics]) == 2 end diff --git a/TypedSyntax/test/test_module.jl b/TypedSyntax/test/test_module.jl index 1b7b7dfe..64fe596a 100644 --- a/TypedSyntax/test/test_module.jl +++ b/TypedSyntax/test/test_module.jl @@ -50,10 +50,10 @@ zerowhere(::AbstractArray{T}) where T<:Real = zero(T) vaparam(a::AbstractArray{T,N}, I::NTuple{N,Any}) where {T,N} = N @inline function myplustv(x::T, y::Integer) where {T<:AbstractChar} # vendored copy of +(::T, ::Integer) where T<:AbstractChar if x isa Char - u = Int32((bitcast(UInt32, x) >> 24) % Int8) + u = Int32((Base.bitcast(UInt32, x) >> 24) % Int8) if u >= 0 # inline the runtime fast path z = u + y - return 0 <= z < 0x80 ? bitcast(Char, (z % UInt32) << 24) : Char(UInt32(z)) + return 0 <= z < 0x80 ? Base.bitcast(Char, (z % UInt32) << 24) : Char(UInt32(z)) end end return T(Int32(x) + Int32(y)) @@ -242,4 +242,23 @@ function f493() sum(rand(T, 100)) end +function obfuscated(x) + f = sin + return f(x) +end + +module Internal +export helper +helper(x) = x+1 +module MoreInternal +helper2(x) = x+2 +end +end +using .Internal +calls_helper(x) = helper(x) +calls_helper1(x) = Internal.helper(x) +calls_helper2(x) = Internal.MoreInternal.helper2(x) + +allbutfirst(list) = list[2:end] + end