Skip to content

Commit

Permalink
Merge pull request #14 from RelationalAI-oss/nhd-rai-upstream
Browse files Browse the repository at this point in the history
Upstream changes from RelationalAI:
  • Loading branch information
NHDaly authored May 6, 2020
2 parents 4dbe715 + 3860cf0 commit 723dbf9
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 48 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Salsa"
uuid = "1fbf2c77-44e2-4d5d-8131-0fa618a5c278"
authors = ["Nathan Daly <[email protected]>", "Todd J. Green <[email protected]>"]
version = "1.0.0"
version = "1.1.0"

[deps]
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Expand Down
18 changes: 3 additions & 15 deletions examples/SpreadsheetApp/Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,6 @@ git-tree-sha1 = "ed2c4abadf84c53d9e58510b5fc48912c2336fbb"
uuid = "34da2185-b29b-5c13-b0c7-acf172513d20"
version = "2.2.0"

[[DataStructures]]
deps = ["InteractiveUtils", "OrderedCollections"]
git-tree-sha1 = "5a431d46abf2ef2a4d5d00bd0ae61f651cf854c8"
uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
version = "0.17.10"

[[Dates]]
deps = ["Printf"]
uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"
Expand Down Expand Up @@ -45,10 +39,10 @@ uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"

[[MacroTools]]
deps = ["DataStructures", "Markdown", "Random"]
git-tree-sha1 = "07ee65e03e28ca88bc9a338a3726ae0c3efaa94b"
deps = ["Markdown", "Random"]
git-tree-sha1 = "f7d2e3f654af75f01ec49be82c231c382214223a"
uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
version = "0.5.4"
version = "0.5.5"

[[Markdown]]
deps = ["Base64"]
Expand All @@ -57,12 +51,6 @@ uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
[[Mmap]]
uuid = "a63ad114-7e13-5084-954f-fe012c677804"

[[OrderedCollections]]
deps = ["Random", "Serialization", "Test"]
git-tree-sha1 = "c4c13474d23c60d20a67b217f1d7f22a40edf8f1"
uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"

[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Expand Down
82 changes: 64 additions & 18 deletions src/DebugMode.jl
Original file line number Diff line number Diff line change
@@ -1,32 +1,78 @@
module DebugMode
module Debug

export @debug_mode, DBG
export @debug_mode, enable_debug, disable_debug, enable_trace_logging, disable_trace_logging

# This file was modeled on the debug mode found in src/QueryEvaluator/trie_interface.jl

DBG = true
# `static_debug_mode` is a flag that enables/disables all debug mode checks
const static_debug_mode = true


"""
`debug_mode` is a flag that enables/disables the debug mode for the query evaluator
@debug_mode expr...
Execute `expr` only when static and runtime debug modes are enabled.
"""
const debug_mode = true

if debug_mode
"""
Execute only in debug mode
"""
macro debug_mode(instr)
esc(:(
if DBG != Nothing
$instr
macro debug_mode(instr)
if static_debug_mode
quote
if debug_enabled()
$(esc(instr))
end
))
end
else
:()
end
end

if static_debug_mode
# Runtime debug mode controls
function enable_debug()
global _DBG = true
end
function disable_debug()
global _DBG = false
end
debug_enabled() = _DBG
_DBG = true


# Runtime trace logging controls
function enable_trace_logging()
global _tracing = true
end
function disable_trace_logging()
global _tracing = false
end
trace_logging_enabled() = _tracing
_tracing = false
else
macro debug_mode(instr)
# Runtime Debugging is disabled.
_emit_debug_warning() =
@warn """
Cannot enable runtime debug statements because debug is disabled statically.
To enable, reload Salsa after setting `static_debug_mode = true` in:
$(@__FILE__)
"""

enable_debug() = _emit_debug_warning()
disable_debug() = _emit_debug_warning()
debug_enabled() = false

enable_trace_logging() = _emit_debug_warning()
disable_trace_logging() = _emit_debug_warning()
trace_logging_enabled() = false
end

macro dbg_log_trace(expr)
if static_debug_mode
quote
if trace_logging_enabled()
$(esc(expr))
end
end
else
:()
end
end

end # module DebugMode
end # module Debug
39 changes: 26 additions & 13 deletions src/Salsa.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ export @component, @input, @derived, @connect, AbstractComponent, InputScalar, I

import MacroTools
include("DebugMode.jl")
import .DebugMode: @debug_mode, DBG
import .Debug: @debug_mode, @dbg_log_trace


const Revision = Int
Expand Down Expand Up @@ -328,9 +328,7 @@ function memoized_lookup_derived(component, key::DependencyKey)
# At this point (value == nothing) if (and only if) the args are not
# in the cache, OR if they are in the cache, but they are no longer valid.
if value === nothing # N.B., do not use `isnothing`
if get(ENV, "SALSA_TRACE", "0") != "0"
@info "invoking $key"
end
@dbg_log_trace @info "invoking $key"
v = invoke_user_function(key.key, key.args...)
# NOTE: We use `isequal` for the Early Exit Optimization, since values are required
# to be purely immutable (but not necessarily julia `immutable structs`).
Expand Down Expand Up @@ -444,17 +442,20 @@ macro derived(f)
dict = MacroTools.splitdef(f)

fname = dict[:name]
args = dict[:args]

if length(dict[:args]) < 1
if length(args) < 1
throw(ArgumentError("@derived functions must take a Component as the first argument."))
end

# _argnames and _argtypes fill in anonymous names for unnamed args (`::Int`) and `Any`
# for untyped args. `fullargs` will have all args w/ names and types.
argnames = _argnames(dict[:args])
argtypes = _argtypes(dict[:args])
# for untyped args. E.g. Turns `(::Int, _, x, y::Number)` into
# `(var"#2#3"::Int, var"#2#4"::Any, x::Any, y::Number)`.
argnames = _argnames(args)
argtypes = _argtypes(args)
dbname = argnames[1]
fullargs = [Expr(:(::), argnames[i], argtypes[i]) for i in 1:length(dict[:args])]
# In all the generated code, we'll use `args` w/ the full names and types.
args = [Expr(:(::), argnames[i], argtypes[i]) for i in 1:length(args)]

# Get the argument types and return types for building the dictionary types.
# TODO: IS IT okay to eval here? Will function defs always be top-level exprs?
Expand All @@ -477,7 +478,7 @@ macro derived(f)

# Construct the originally named, visible function
dict[:name] = fname
dict[:args] = fullargs
dict[:args] = args # Switch to the fully typed and named arguments.
dict[:body] = quote
key = $DependencyKey(key = $derived_key, args = ($(argnames...),))
$memoized_lookup_derived($(argnames[1]), key).value
Expand All @@ -502,7 +503,7 @@ macro derived(f)
cache
end

function $(@__MODULE__()).invoke_user_function(::$derived_key_t, $(fullargs...))
function $(@__MODULE__()).invoke_user_function(::$derived_key_t, $(args...))
$userfname($(argnames[1]), $(argnames[2:end]...))
end

Expand Down Expand Up @@ -746,7 +747,12 @@ end

function Base.get!(default::Function, input::InputMap{K,V}, key::K) where V where K
assert_safe(input)
get!(() -> InputValue{V}(default(), input.runtime.current_revision), input.v, key).value
return (get!(input.v, key) do
value = default()
@dbg_log_trace @info "Setting input on $(typeof(input)): $key => $value"
input.runtime.current_revision += 1
InputValue{V}(value, input.runtime.current_revision)
end).value
end

# The argument `value` can be anything that can be converted to type `T`. We omit the
Expand All @@ -761,6 +767,7 @@ function Base.setindex!(input::InputScalar{T}, value) where {T}
return
end
assert_safe(input)
@dbg_log_trace @info "Setting input on $(typeof(input)): $value"
input.runtime.current_revision += 1
input.v[] = Some(InputValue{T}(value, input.runtime.current_revision))
input
Expand All @@ -775,18 +782,21 @@ function Base.setindex!(input::InputMap{K,V}, value, key) where {K,V}
return
end
assert_safe(input)
@dbg_log_trace @info "Setting input on $(typeof(input)): $key => $value"
input.runtime.current_revision += 1
input.v[key] = InputValue{V}(value, input.runtime.current_revision)
input
end

function Base.delete!(input::InputMap, key)
assert_safe(input)
@dbg_log_trace @info "Deleting input on $input: $key"
input.runtime.current_revision += 1
delete!(input.v, key)
input
end
function Base.empty!(input :: InputMap)
@dbg_log_trace @info "Emptying input $input"
input.runtime.current_revision += 1
empty!(input.v)
input
Expand Down Expand Up @@ -814,7 +824,7 @@ function memoized_lookup_input_helper(runtime::Runtime, input::InputTypes, key::
local value
trace_with_key(runtime, key) do
value = getindex(input.v, call_args...)
end # do block
end # do block
return value
end

Expand Down Expand Up @@ -904,4 +914,7 @@ struct ProvideDecl
decl::Expr
end

# Offline / Debug includes (At the end so they can access the full package)
include("inspect.jl") # For offline debugging/inspecting a Salsa state.

end # module Salsa
125 changes: 125 additions & 0 deletions src/inspect.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
module InspectSalsa
using Salsa

# Crude code to dump a .dot (graphviz) version of the Arroyo Node tree.
# For debugging / illustration purposes - not intended to be reachable from production
# code.
function dump_graph(c::Salsa.AbstractComponent; module_boxes=false)
println(build_graph(c; module_boxes=module_boxes))
end

# There is no IdSet in julia, so we build our own from IdDict.
# We need an IdSet to prevent comparing all the contents of the various maps, which is
# expensive. We just want to compare via `===` which is what an IdDict does.
# (Used the name _IdSet to indicate that this is internal only, and not a full-featured set)
struct _IdSet{T}
d::IdDict{T,Nothing} # Only using the Keys
_IdSet{T}() where T = new(IdDict{T,Nothing}())
end
Base.in(k, s::_IdSet) = haskey(s.d, k)
Base.push!(s::_IdSet, k) = s.d[k] = nothing

function build_graph(c::Salsa.AbstractComponent; module_boxes=false)
rt = Salsa.get_runtime(c)
io = IOBuffer()
println(io, """digraph G {""")
println(io, """edge [dir="back"];""")
seen = _IdSet{Any}()
modules_map = Dict{Module,Set}()
edges = Dict{Pair, Int}()
inputs = Dict{Any, String}() # Note, there might be duplicate strings, Any must be the key

_build_graph(io, c, seen, modules_map, edges, inputs)

if module_boxes
for (m,keys) in modules_map
mname = nameof(m)
println(io, """subgraph cluster_$mname {""")
#println(io, """node [style=filled];""")
# First print non-input keys
println(io, """ $(join((repr(vertex_name(k)) for k in keys
if !haskey(inputs, k)), " ")) """);
begin # Then print input keys all on the bottom
println(io, "{")
println(io, "rank=sink;")
println(io, """ $(join((repr(vertex_name(k)) for k in keys
if haskey(inputs, k)), " ")) """);
println(io, "}")
end
println(io, """ label = "Module `$mname`"; """)
println(io, """ fontsize = 25; """)
println(io, """ color=blue; """)
println(io, """}""")
end
end

if !module_boxes
println(io, "{")
println(io, "rank=sink;")
end
for (input_key, name) in inputs
println(io, """$(vertex_name(input_key)) [label="$name"]""")
end
if !module_boxes
println(io, "}")
end

max_count = maximum(values(edges))
maxwidth = 10
for ((a,b), count) in edges
normwidth = 1 + maxwidth * (count / max_count)
println(io, """ "$(vertex_name(a))" -> "$(vertex_name(b))" [penwidth=$normwidth, weight=$normwidth]""")
end
@show max_count
println(io, "}")
return String(take!(io))
end

function _build_graph(io::IO, c::Salsa.AbstractComponent, seen::_IdSet, modules_map::Dict, edges::Dict, inputs::Dict)
rt = c.runtime
m = typeof(c).name.module
for fieldname in fieldnames(typeof(c))
f = getfield(c, fieldname)
if f in seen
continue
else
push!(seen, f)
end
if f isa Salsa.AbstractComponent
@show typeof(f)
_build_graph(io, f, seen, modules_map, edges, inputs)
elseif f isa Salsa.InputTypes
key = Salsa.InputKey(f)
inputs[key] = "@input: $(fieldname)"
push!(get!(modules_map, m, Set([])), key)
end
end

for (derived_key,derived_map) in rt.derived_function_maps
_build_graph(io, rt, derived_key, derived_map, seen, modules_map, edges)
end
end

function vertex_name(c::Any)::String
return "v$(objectid(c))"
end

function _build_graph(io, rt::Salsa.Runtime, derived_key::Salsa.DerivedKey{F,TT}, derived_map::Dict,
seen::_IdSet{Any}, modules_map::Dict{Module,Set}, edges::Dict{Pair, Int}) where {F,TT}
in(derived_key, seen) && return
push!(seen, derived_key)
m = methods(F.instance).mt.module
push!(get!(modules_map, m, Set([])), derived_key)
println(io, "$(vertex_name(derived_key)) [shape=rect,label=\"$(derived_key)\"]")
#println(io, "$(vertex_name(derived_key)) [label=\"$(derived_key)\"]")
for (k,v) in derived_map
for d in v.dependencies
edge = (derived_key) => (d.key)
count = get!(edges, edge, 0) + 1
edges[edge] = count
end
end
#_build_graph(io, s.leaf_node, seen)
end

end # module
Loading

0 comments on commit 723dbf9

Please sign in to comment.