Commit: Update
odow committed Oct 24, 2024
1 parent e1e73a5 commit 0ca4c6f
Showing 4 changed files with 17 additions and 25 deletions.
src/predictors/Sigmoid.jl — 10 changes: 1 addition & 9 deletions
@@ -58,17 +58,9 @@ ReducedSpace(Sigmoid())
 """
 struct Sigmoid <: AbstractPredictor end
 
-_eval(::Sigmoid, x::Real) = 1 / (1 + exp(-x))
-
 function add_predictor(model::JuMP.AbstractModel, predictor::Sigmoid, x::Vector)
     y = JuMP.@variable(model, [1:length(x)], base_name = "moai_Sigmoid")
-    cons = Any[]
-    for i in 1:length(x)
-        x_l, x_u = _get_variable_bounds(x[i])
-        y_l = x_l === nothing ? 0 : _eval(predictor, x_l)
-        y_u = x_u === nothing ? 1 : _eval(predictor, x_u)
-        _set_bounds_if_finite(cons, y[i], y_l, y_u)
-    end
+    cons = _set_direct_bounds(x -> 1 / (1 + exp(-x)), 0, 1, x, y)
     append!(cons, JuMP.@constraint(model, y .== 1 ./ (1 .+ exp.(-x))))
     return y, Formulation(predictor, y, cons)
 end
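
Note on the new one-liner: because the sigmoid is monotone increasing, it maps an input interval [x_l, x_u] onto [sigmoid(x_l), sigmoid(x_u)], and the fallback bounds 0 and 1 passed to _set_direct_bounds are its limits at -∞ and +∞. A minimal standalone sketch of that bound propagation (plain Julia, illustrative only, not part of the commit):

    sigma(x) = 1 / (1 + exp(-x))

    # A monotone increasing f maps [x_l, x_u] onto [f(x_l), f(x_u)],
    # so bounded inputs give output bounds tighter than the (0, 1) fallback.
    x_l, x_u = -2.0, 3.0
    y_l, y_u = sigma(x_l), sigma(x_u)  # ≈ 0.1192 and ≈ 0.9526
    @assert 0 < y_l < y_u < 1
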
src/predictors/SoftPlus.jl — 12 changes: 3 additions & 9 deletions
@@ -61,22 +61,16 @@ struct SoftPlus <: AbstractPredictor
     SoftPlus(; beta::Float64 = 1.0) = new(beta)
 end
 
-_eval(f::SoftPlus, x::Real) = log(1 + exp(f.beta * x)) / f.beta
+_softplus(f::SoftPlus, x::Real) = log(1 + exp(f.beta * x)) / f.beta
 
 function add_predictor(
     model::JuMP.AbstractModel,
     predictor::SoftPlus,
     x::Vector,
 )
-    y = JuMP.@variable(model, [1:length(x)], base_name = "moai_SoftPlus")
-    cons = Any[]
-    for i in 1:length(x)
-        x_l, x_u = _get_variable_bounds(x[i])
-        y_l = x_l === nothing ? 0 : _eval(predictor, x_l)
-        y_u = x_u === nothing ? nothing : _eval(predictor, x_u)
-        _set_bounds_if_finite(cons, y[i], y_l, y_u)
-    end
     beta = predictor.beta
+    y = JuMP.@variable(model, [1:length(x)], base_name = "moai_SoftPlus")
+    cons = _set_direct_bounds(x -> _softplus(predictor, x), 0, nothing, x, y)
     append!(
         cons,
         JuMP.@constraint(model, y .== log.(1 .+ exp.(beta .* x)) ./ beta),
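
For reference, the SoftPlus used here is f(x) = log(1 + exp(beta * x)) / beta, which is monotone increasing, tends to 0 as x → -∞, and is unbounded above; that is why the call above passes 0 as the fallback lower bound and nothing as the upper one. A quick numeric sanity check (illustrative only, not part of the commit):

    softplus(beta, x) = log(1 + exp(beta * x)) / beta

    softplus(2.0, -10.0)  # ≈ 1.0e-9: approaches the lower limit 0
    softplus(2.0, 10.0)   # ≈ 10.0: grows like x, so no finite upper bound
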
src/predictors/Tanh.jl — 8 changes: 1 addition & 7 deletions
@@ -62,13 +62,7 @@ _eval(::Tanh, x::Real) = tanh(x)
 
 function add_predictor(model::JuMP.AbstractModel, predictor::Tanh, x::Vector)
     y = JuMP.@variable(model, [1:length(x)], base_name = "moai_Tanh")
-    cons = Any[]
-    for i in 1:length(x)
-        x_l, x_u = _get_variable_bounds(x[i])
-        y_l = x_l === nothing ? -1 : _eval(predictor, x_l)
-        y_u = x_u === nothing ? 1 : _eval(predictor, x_u)
-        _set_bounds_if_finite(cons, y[i], y_l, y_u)
-    end
+    cons = _set_direct_bounds(tanh, -1, 1, x, y)
     append!(cons, JuMP.@constraint(model, y .== tanh.(x)))
     return y, Formulation(predictor, y, cons)
 end
src/utilities.jl — 12 changes: 12 additions & 0 deletions
@@ -43,3 +43,15 @@ _get_variable_bounds(::Any) = -Inf, Inf
 
 # Default fallback: skip setting variable bound
 _set_bounds_if_finite(::Vector, ::Any, ::Any, ::Any) = nothing
+
+
+function _set_direct_bounds(f::F, l, u, x::Vector, y::Vector) where {F}
+    cons = Any[]
+    for (xi, yi) in zip(x, y)
+        x_l, x_u = _get_variable_bounds(xi)
+        y_l = x_l === nothing ? l : f(x_l)
+        y_u = x_u === nothing ? u : f(x_u)
+        _set_bounds_if_finite(cons, yi, y_l, y_u)
+    end
+    return cons
+end
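
Taken together: _set_direct_bounds is the shared replacement for the three per-predictor loops deleted above. Given a scalar function f (assumed monotone increasing, as sigmoid, softplus, and tanh all are), fallback bounds l and u for inputs with no finite bounds, and vectors of input and output variables, it bounds each y[i] by the image of x[i]'s interval under f. Below is a self-contained toy version with the JuMP-specific helpers replaced by plain (lo, hi) tuples, to show which bounds the helper derives; all names here are illustrative and not part of the package:

    # Stand-in for _get_variable_bounds: each input is a (lo, hi) tuple,
    # with `nothing` meaning "no bound on that side".
    function toy_direct_bounds(f, l, u, xs)
        return map(xs) do (x_l, x_u)
            y_l = x_l === nothing ? l : f(x_l)
            y_u = x_u === nothing ? u : f(x_u)
            (y_l, y_u)
        end
    end

    toy_direct_bounds(tanh, -1, 1, [(-1.0, 1.0), (nothing, 2.0)])
    # => [(-0.7616, 0.7616), (-1, 0.9640)]  (rounded)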
