diff --git a/Project.toml b/Project.toml index a8378ec..2ad0ac3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SpineInterface" uuid = "0cda1612-498a-11e9-3c92-77fa82595a4f" authors = ["Spine Project consortium "] -version = "0.13.4" +version = "0.13.5" [deps] DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" diff --git a/src/api/core.jl b/src/api/core.jl index 3f3155a..a67e3c8 100644 --- a/src/api/core.jl +++ b/src/api/core.jl @@ -163,9 +163,11 @@ _find_rels(rc, rows) = (rc.relationships[row] for row in rows) _find_rels(rc, ::Anything) = rc.relationships function _find_rows(rc; kwargs...) - memoized_rows = get!(rc.row_map, rc.name, Dict()) - get!(memoized_rows, kwargs) do - _do_find_rows(rc; kwargs...) + lock(rc.row_map_lock) do + memoized_rows = get!(rc.row_map, rc.name, Dict()) + get!(memoized_rows, kwargs) do + _do_find_rows(rc; kwargs...) + end end end @@ -714,7 +716,7 @@ function with_env(f::Function, env::Symbol) prev_env = _active_env() _activate_env(env) try - f() + return f() finally _activate_env(prev_env) end diff --git a/src/types.jl b/src/types.jl index b14e921..1af7990 100644 --- a/src/types.jl +++ b/src/types.jl @@ -72,12 +72,13 @@ struct TimeSlice id::UInt64 actual_duration::Union{Dates.CompoundPeriod,Period} updates::OrderedDict + updates_lock::ReentrantLock function TimeSlice(start, end_, duration, blocks) start > end_ && error("out of order") blocks = isempty(blocks) ? () : Tuple(sort(collect(blocks))) id = objectid((start, end_, duration, blocks)) actual_duration = canonicalize(end_ - start) - new(Ref(start), Ref(end_), duration, blocks, id, actual_duration, OrderedDict()) + new(Ref(start), Ref(end_), duration, blocks, id, actual_duration, OrderedDict(), ReentrantLock()) end end @@ -119,11 +120,21 @@ struct _RelationshipClass parameter_values::Dict{ObjectTupleLike,Dict{Symbol,ParameterValue}} parameter_defaults::Dict{Symbol,ParameterValue} row_map::Dict + row_map_lock::ReentrantLock _split_kwargs::Ref{Any} function _RelationshipClass(name, intact_cls_names, object_tuples, vals=Dict(), defaults=Dict()) cls_names = _fix_name_ambiguity(intact_cls_names) - row_map = Dict() - rc = new(name, intact_cls_names, cls_names, [], vals, defaults, row_map, _make_split_kwargs(cls_names)) + rc = new( + name, + intact_cls_names, + cls_names, + [], + vals, + defaults, + Dict(), + ReentrantLock(), + _make_split_kwargs(cls_names), + ) rels = [(; zip(cls_names, objects)...) for objects in object_tuples] _append_relationships!(rc, rels) rc diff --git a/src/update_model.jl b/src/update_model.jl index 2a70317..dbcdf3d 100644 --- a/src/update_model.jl +++ b/src/update_model.jl @@ -27,6 +27,8 @@ import .JuMP: MOI, MOIU, MutableArithmetics _Constant = Union{Number,UniformScaling} +const _si_ext_lock = ReentrantLock() + struct SpineInterfaceExt lower_bound::Dict{VariableRef,Any} upper_bound::Dict{VariableRef,Any} @@ -34,6 +36,14 @@ struct SpineInterfaceExt SpineInterfaceExt() = new(Dict(), Dict(), Dict()) end +function _get_si_ext!(m) + lock(_si_ext_lock) do + get!(m.ext, :spineinterface) do + SpineInterfaceExt() + end + end +end + JuMP.copy_extension_data(data::SpineInterfaceExt, new_model::AbstractModel, model::AbstractModel) = nothing abstract type _CallSet <: MOI.AbstractScalarSet end @@ -141,7 +151,7 @@ function _set_lower_bound(var, lb) if is_fixed(var) # Save bound m = owner_model(var) - ext = get!(m.ext, :spineinterface, SpineInterfaceExt()) + ext = _get_si_ext!(m) ext.lower_bound[var] = lb elseif !isnan(lb) set_lower_bound(var, lb) @@ -164,7 +174,7 @@ function _set_upper_bound(var, ub) if is_fixed(var) # Save bound m = owner_model(var) - ext = get!(m.ext, :spineinterface, SpineInterfaceExt()) + ext = _get_si_ext!(m) ext.upper_bound[var] = ub elseif !isnan(ub) set_upper_bound(var, ub) @@ -189,7 +199,7 @@ _fix(_upd, ::Nothing) = nothing function _fix(upd, fix_value) var = upd.variable m = owner_model(var) - ext = get!(m.ext, :spineinterface, SpineInterfaceExt()) + ext = _get_si_ext!(m) if !isnan(fix_value) # Save bounds, remove them and then fix the value if has_lower_bound(var) diff --git a/src/util.jl b/src/util.jl index 71f6b9c..c48b138 100644 --- a/src/util.jl +++ b/src/util.jl @@ -225,7 +225,9 @@ function _refresh_metadata!(pval::ParameterValue) end function _add_update(t::TimeSlice, timeout, upd) - t.updates[upd] = timeout + lock(t.updates_lock) do + t.updates[upd] = timeout + end end function _append_relationships!(rc, rels)