Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop unwrapping types while mapping #585

Merged
merged 11 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 16 additions & 15 deletions TypedSyntax/src/node.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,12 +191,12 @@ function map_signature!(sig::TypedSyntaxNode, slotnames::Vector{Symbol}, slottyp
kwdivider = 1
if havekws && slotnames[1] !== Symbol("#self#")
kwdivider = findfirst(1:length(slotnames)) do i
slotnames[i] == Symbol("") && unwrapinternal(slottypes[i]) <: Function # this should be the parent function as an argument
slotnames[i] == Symbol("") && isa(unwrapinternal(slottypes[i]), Function) # this should be the parent function as an argument
end
if kwdivider === nothing
kwdivider = 1
end
if length(slottypes) >= 2 && slotnames[2] == Symbol("") && (nt = unwrapinternal(slottypes[2])) <: NamedTuple
if length(slottypes) >= 2 && slotnames[2] == Symbol("") && (nt = unwrapinternal(slottypes[2]); isa(nt, Type)) && nt <: NamedTuple
# Match kwargs
argcontainer = children(last(children(sig)))
offset = length(children(sig)) - 1
Expand Down Expand Up @@ -244,7 +244,7 @@ function map_signature!(sig::TypedSyntaxNode, slotnames::Vector{Symbol}, slottyp
if kind(arg) == K"::" && length(children(arg)) == 2
arg = child(arg, 1)
end
arg.typ = unwrapinternal(slottypes[idx])
arg.typ = slottypes[idx]
end

# It's annoying to print the signature as `foo::typeof(foo)(a::Int)`
Expand Down Expand Up @@ -276,7 +276,7 @@ function striparg(arg)
end

function unwrapinternal(@nospecialize(T))
isa(T, Core.Const) && return Core.Typeof(T.val)
isa(T, Core.Const) && return T.val
isa(T, Core.PartialStruct) && return T.typ
return T
end
Expand All @@ -287,10 +287,10 @@ function gettyp(node2ssa, node, src)
ssavaluetypes = src.ssavaluetypes::Vector{Any}
if isa(stmt, Core.ReturnNode)
arg = stmt.val
isa(arg, SSAValue) && return unwrapinternal(ssavaluetypes[arg.id])
is_slot(arg) && return unwrapinternal((src.slottypes::Vector{Any})[arg.id])
isa(arg, SSAValue) && return ssavaluetypes[arg.id]
is_slot(arg) && return (src.slottypes::Vector{Any})[arg.id]
end
return unwrapinternal(ssavaluetypes[i])
return ssavaluetypes[i]
end

Base.copy(tsd::TypedSyntaxData) = TypedSyntaxData(tsd.source, tsd.typedsource, tsd.raw, tsd.position, tsd.val, tsd.typ, tsd.runtime)
Expand Down Expand Up @@ -608,7 +608,7 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN
argmapping = typeof(rootnode)[] # temporary storage
for (i, mapped, stmt) in zip(eachindex(mappings), mappings, src.code)
empty!(argmapping)
if is_slot(stmt) || isa(stmt, SSAValue)
if is_slot(stmt) || isa(stmt, SSAValue) || isa(stmt, GlobalRef)
append_targets_for_arg!(mapped, i, stmt)
elseif isa(stmt, Core.ReturnNode)
append_targets_for_line!(mapped, i, append_targets_for_arg!(argmapping, i, stmt.val))
Expand All @@ -626,16 +626,14 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN
append_targets_for_arg!(mapped, i, stmt)
filter_assignment_targets!(mapped, true) # match the RHS of assignments
if length(mapped) == 1
symtyps[only(mapped)] = unwrapinternal(
(is_slot(stmt) & have_slottypes) ? slottypes[(stmt::SlotType).id] :
symtyps[only(mapped)] = (is_slot(stmt) & have_slottypes) ? slottypes[(stmt::SlotType).id] :
isa(stmt, SSAValue) ? ssavaluetypes[stmt.id] : #=literal=#typeof(stmt)
)
end
# Now try to assign types to the LHS of the assignment
append_targets_for_arg!(argmapping, i, lhs)
filter_assignment_targets!(argmapping, false) # match the LHS of assignments
if length(argmapping) == 1
T = unwrapinternal(ssavaluetypes[i])
T = ssavaluetypes[i]
symtyps[only(argmapping)] = T
end
empty!(argmapping)
Expand Down Expand Up @@ -738,7 +736,7 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN
if isexpr(nextstmt, :call)
f = nextstmt.args[1]
if isa(f, GlobalRef) && f.mod == Base && f.name == :broadcasted
empty!(mapped)
# empty!(mapped)
break
elseif isa(f, GlobalRef) && f.mod == Base && f.name == :materialize && nextstmt.args[2] === SSAValue(i)
push!(mappings[inext], node)
Expand Down Expand Up @@ -793,14 +791,14 @@ function map_ssas_to_source(src::CodeInfo, mi::MethodInstance, rootnode::SyntaxN
haskey(symtyps, t) && continue
if skipped_parent(t) == node
is_prec_assignment(node) && t == child(node, 1) && continue
symtyps[t] = unwrapinternal(if j > 0
symtyps[t] = if j > 0
ssavaluetypes[j]
elseif have_slottypes
# We failed to find it as an SSAValue, it must have type assigned at function entry
slottypes[arg.id]
else
nothing
end)
end
break
end
end
Expand Down Expand Up @@ -904,6 +902,9 @@ function skipped_parent(node::SyntaxNode)
pnode === nothing && return node
ppnode = pnode.parent
if ppnode !== nothing && kind(pnode) ∈ KSet"... quote" # might need to add more things here
if kind(node) == K"Identifier" && kind(pnode) == K"quote" && kind(ppnode) == K"." && sourcetext(node) == "materialize"
return ppnode.parent
end
return ppnode
end
return pnode
Expand Down
45 changes: 41 additions & 4 deletions TypedSyntax/src/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,22 +87,56 @@ end
function is_show_annotation(@nospecialize(T); type_annotations::Bool, hide_type_stable::Bool)
type_annotations || return false
if isa(T, Core.Const)
T = typeof(T.val)
isa(T.val, Module) && return false
T = Core.Typeof(T.val)
end
isa(T, Type) || return false
hide_type_stable || return true
return isa(T, Type) && is_type_unstable(T)
end

# Is the type equivalent to the source-text?
# We use `endswith` to handle module qualification
is_type_transparent(node, @nospecialize(T)) = endswith(replace(sprint(show, T), r"\s" => ""), replace(sourcetext(node), r"\s" => ""))

function is_callfunc(node::MaybeTypedSyntaxNode, @nospecialize(T))
thisnode = node
pnode = node.parent
while pnode !== nothing && kind(pnode) ∈ KSet"quote ." && pnode.parent !== nothing
thisnode = pnode
pnode = pnode.parent
end
if pnode !== nothing && kind(pnode) ∈ (K"call", K"curly") && ((is_infix_op_call(pnode) && is_operator(thisnode)) || thisnode === pnode.children[1])
if isa(T, Core.Const)
T = T.val
end
if isa(T, Type) || isa(T, Function)
T === Colon() && sourcetext(node) == ":" && return true
return is_type_transparent(node, T)
end
end
return false
end

function type_annotation_mode(node, @nospecialize(T); type_annotations::Bool, hide_type_stable::Bool)
kind(node) == K"return" && return false, "", "", ""
is_callfunc(node, T) && return false, "", "", ""
type_annotate = is_show_annotation(T; type_annotations, hide_type_stable)
pre = pre2 = post = ""
if type_annotate
if T isa DataType && T <: Type && isassigned(T.parameters, 1)
if replace(sourcetext(node), r"\s" => "") == replace(sprint(show, T.parameters[1]), r"\s" => "")
return false, pre, pre2, post
# Try stripping Core.Const and Type{T} wrappers to check if we need to avoid `String::Type{String}`
# or `String::Core.Const(String)` annotations
S = nothing
if isa(T, Core.Const)
val = T.val
if isa(val, DataType)
S = val
end
elseif isa(T, DataType) && T <: Type && isassigned(T.parameters, 1)
S = T.parameters[1]
end
if S !== nothing && is_type_transparent(node, S)
return false, pre, pre2, post
end
if kind(node) ∈ KSet":: where" || is_infix_op_call(node) || (is_prec_assignment(node) && kind(node) != K"=")
pre, post = "(", ")"
Expand All @@ -118,6 +152,9 @@ function show_annotation(io, @nospecialize(T), post, node, position; iswarn::Boo
inlay_hints = get(io, :inlay_hints, nothing)

print(io, post)
if isa(T, Core.Const) && isa(T.val, Type)
T = Type{T.val}
end
T_str = string(T)
if iswarn && is_type_unstable(T)
color = is_small_union_or_tunion(T) ? :yellow : :red
Expand Down
Loading
Loading